1 //---------------------------------------------------------------------
2 // <copyright file="DynamicUpdateCommand.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //
6 // @owner Microsoft
7 // @backupOwner Microsoft
8 //---------------------------------------------------------------------
9 
10 
11 using System.Collections.Generic;
12 using System.Data.Common.CommandTrees;
13 using System.Data.Metadata.Edm;
14 using System.Data.Common;
15 using System.Data.EntityClient;
16 using System.Diagnostics;
17 using System.Data.Common.Utils;
18 using System.Linq;
19 using System.Data.Common.CommandTrees.ExpressionBuilder;
20 using System.Data.Spatial;
21 
22 namespace System.Data.Mapping.Update.Internal
23 {
24     internal sealed class DynamicUpdateCommand : UpdateCommand
25     {
26         private readonly ModificationOperator m_operator;
27         private readonly TableChangeProcessor m_processor;
28         private readonly List<KeyValuePair<int, DbSetClause>> m_inputIdentifiers;
29         private readonly Dictionary<int, string> m_outputIdentifiers;
30         private readonly DbModificationCommandTree m_modificationCommandTree;
31 
32 
DynamicUpdateCommand(TableChangeProcessor processor, UpdateTranslator translator, ModificationOperator op, PropagatorResult originalValues, PropagatorResult currentValues, DbModificationCommandTree tree, Dictionary<int, string> outputIdentifiers)33         internal DynamicUpdateCommand(TableChangeProcessor processor, UpdateTranslator translator, ModificationOperator op,
34             PropagatorResult originalValues, PropagatorResult currentValues, DbModificationCommandTree tree,
35             Dictionary<int, string> outputIdentifiers)
36             : base(originalValues, currentValues)
37         {
38             m_processor = EntityUtil.CheckArgumentNull(processor, "processor");
39             m_operator = op;
40             m_modificationCommandTree = EntityUtil.CheckArgumentNull(tree, "commandTree");
41             m_outputIdentifiers = outputIdentifiers; // may be null (not all commands have output identifiers)
42 
43             // initialize identifier information (supports lateral propagation of server gen values)
44             if (ModificationOperator.Insert == op || ModificationOperator.Update == op)
45             {
46                 const int capacity = 2; // "average" number of identifiers per row
47                 m_inputIdentifiers = new List<KeyValuePair<int ,DbSetClause>>(capacity);
48 
49                 foreach (KeyValuePair<EdmMember, PropagatorResult> member in
50                     Helper.PairEnumerations(TypeHelpers.GetAllStructuralMembers(this.CurrentValues.StructuralType),
51                                              this.CurrentValues.GetMemberValues()))
52                 {
53                     DbSetClause setter;
54                     int identifier = member.Value.Identifier;
55 
56                     if (PropagatorResult.NullIdentifier != identifier &&
57                         TryGetSetterExpression(tree, member.Key, op, out setter)) // can find corresponding setter
58                     {
59                         foreach (int principal in translator.KeyManager.GetPrincipals(identifier))
60                         {
61                             m_inputIdentifiers.Add(new KeyValuePair<int, DbSetClause>(principal, setter));
62                         }
63                     }
64                 }
65             }
66         }
67 
68         // effects: try to find setter expression for the given member
69         // requires: command tree must be an insert or update tree (since other DML trees hnabve
TryGetSetterExpression(DbModificationCommandTree tree, EdmMember member, ModificationOperator op, out DbSetClause setter)70         private static bool TryGetSetterExpression(DbModificationCommandTree tree, EdmMember member, ModificationOperator op, out DbSetClause setter)
71         {
72             Debug.Assert(op == ModificationOperator.Insert || op == ModificationOperator.Update, "only inserts and updates have setters");
73             IEnumerable<DbModificationClause> clauses;
74             if (ModificationOperator.Insert == op)
75             {
76                 clauses = ((DbInsertCommandTree)tree).SetClauses;
77             }
78             else
79             {
80                 clauses = ((DbUpdateCommandTree)tree).SetClauses;
81             }
82             foreach (DbSetClause setClause in clauses)
83             {
84                 // check if this is the correct setter
85                 if (((DbPropertyExpression)setClause.Property).Property.EdmEquals(member))
86                 {
87                     setter = setClause;
88                     return true;
89                 }
90             }
91 
92             // no match found
93             setter = null;
94             return false;
95         }
96 
Execute(UpdateTranslator translator, EntityConnection connection, Dictionary<int, object> identifierValues, List<KeyValuePair<PropagatorResult, object>> generatedValues)97         internal override long Execute(UpdateTranslator translator, EntityConnection connection, Dictionary<int, object> identifierValues, List<KeyValuePair<PropagatorResult, object>> generatedValues)
98         {
99             // Compile command
100             using (DbCommand command = this.CreateCommand(translator, identifierValues))
101             {
102                 // configure command to use the connection and transaction for this session
103                 command.Transaction = ((null != connection.CurrentTransaction) ? connection.CurrentTransaction.StoreTransaction : null);
104                 command.Connection = connection.StoreConnection;
105                 if (translator.CommandTimeout.HasValue)
106                 {
107                     command.CommandTimeout = translator.CommandTimeout.Value;
108                 }
109 
110                 // Execute the query
111                 int rowsAffected;
112                 if (m_modificationCommandTree.HasReader)
113                 {
114                     // retrieve server gen results
115                     rowsAffected = 0;
116                     using (DbDataReader reader = command.ExecuteReader(CommandBehavior.SequentialAccess))
117                     {
118                         if (reader.Read())
119                         {
120                             rowsAffected++;
121 
122                             IBaseList<EdmMember> members = TypeHelpers.GetAllStructuralMembers(this.CurrentValues.StructuralType);
123 
124                             for (int ordinal = 0; ordinal < reader.FieldCount; ordinal++)
125                             {
126                                 // column name of result corresponds to column name of table
127                                 string columnName = reader.GetName(ordinal);
128                                 EdmMember member = members[columnName];
129                                 object value;
130                                 if (Helper.IsSpatialType(member.TypeUsage) && !reader.IsDBNull(ordinal))
131                                 {
132                                     value = SpatialHelpers.GetSpatialValue(translator.MetadataWorkspace, reader, member.TypeUsage, ordinal);
133                                 }
134                                 else
135                                 {
136                                     value = reader.GetValue(ordinal);
137                                 }
138 
139                                 // retrieve result which includes the context for back-propagation
140                                 int columnOrdinal = members.IndexOf(member);
141                                 PropagatorResult result = this.CurrentValues.GetMemberValue(columnOrdinal);
142 
143                                 // register for back-propagation
144                                 generatedValues.Add(new KeyValuePair<PropagatorResult, object>(result, value));
145 
146                                 // register identifier if it exists
147                                 int identifier = result.Identifier;
148                                 if (PropagatorResult.NullIdentifier != identifier)
149                                 {
150                                     identifierValues.Add(identifier, value);
151                                 }
152                             }
153                         }
154 
155                         // Consume the current reader (and subsequent result sets) so that any errors
156                         // executing the command can be intercepted
157                         CommandHelper.ConsumeReader(reader);
158                     }
159                 }
160                 else
161                 {
162                     rowsAffected = command.ExecuteNonQuery();
163                 }
164 
165                 return rowsAffected;
166             }
167         }
168 
169         /// <summary>
170         /// Gets DB command definition encapsulating store logic for this command.
171         /// </summary>
CreateCommand(UpdateTranslator translator, Dictionary<int, object> identifierValues)172         private DbCommand CreateCommand(UpdateTranslator translator, Dictionary<int, object> identifierValues)
173         {
174             DbModificationCommandTree commandTree = m_modificationCommandTree;
175 
176             // check if any server gen identifiers need to be set
177             if (null != m_inputIdentifiers)
178             {
179                 Dictionary<DbSetClause, DbSetClause> modifiedClauses = new Dictionary<DbSetClause, DbSetClause>();
180                 for (int idx = 0; idx < m_inputIdentifiers.Count; idx++)
181                 {
182                     KeyValuePair<int, DbSetClause> inputIdentifier = m_inputIdentifiers[idx];
183 
184                     object value;
185                     if (identifierValues.TryGetValue(inputIdentifier.Key, out value))
186                     {
187                         // reset the value of the identifier
188                         DbSetClause newClause = new DbSetClause(inputIdentifier.Value.Property, DbExpressionBuilder.Constant(value));
189                         modifiedClauses[inputIdentifier.Value] = newClause;
190                         m_inputIdentifiers[idx] = new KeyValuePair<int, DbSetClause>(inputIdentifier.Key, newClause);
191                     }
192                 }
193                 commandTree = RebuildCommandTree(commandTree, modifiedClauses);
194             }
195 
196             return translator.CreateCommand(commandTree);
197         }
198 
RebuildCommandTree(DbModificationCommandTree originalTree, Dictionary<DbSetClause, DbSetClause> clauseMappings)199         private DbModificationCommandTree RebuildCommandTree(DbModificationCommandTree originalTree, Dictionary<DbSetClause, DbSetClause> clauseMappings)
200         {
201             if (clauseMappings.Count == 0)
202             {
203                 return originalTree;
204             }
205 
206             DbModificationCommandTree result;
207             Debug.Assert(originalTree.CommandTreeKind == DbCommandTreeKind.Insert || originalTree.CommandTreeKind == DbCommandTreeKind.Update, "Set clauses specified for a modification tree that is not an update or insert tree?");
208             if (originalTree.CommandTreeKind == DbCommandTreeKind.Insert)
209             {
210                 DbInsertCommandTree insertTree = (DbInsertCommandTree)originalTree;
211                 result = new DbInsertCommandTree(insertTree.MetadataWorkspace, insertTree.DataSpace,
212                     insertTree.Target, ReplaceClauses(insertTree.SetClauses, clauseMappings).AsReadOnly(), insertTree.Returning);
213             }
214             else
215             {
216                 DbUpdateCommandTree updateTree = (DbUpdateCommandTree)originalTree;
217                 result = new DbUpdateCommandTree(updateTree.MetadataWorkspace, updateTree.DataSpace,
218                     updateTree.Target, updateTree.Predicate, ReplaceClauses(updateTree.SetClauses, clauseMappings).AsReadOnly(), updateTree.Returning);
219             }
220 
221             return result;
222         }
223 
224         /// <summary>
225         /// Creates a new list of modification clauses with the specified remapped clauses replaced.
226         /// </summary>
ReplaceClauses(IList<DbModificationClause> originalClauses, Dictionary<DbSetClause, DbSetClause> mappings)227         private List<DbModificationClause> ReplaceClauses(IList<DbModificationClause> originalClauses, Dictionary<DbSetClause, DbSetClause> mappings)
228         {
229             List<DbModificationClause> result = new List<DbModificationClause>(originalClauses.Count);
230             for (int idx = 0; idx < originalClauses.Count; idx++)
231             {
232                 DbSetClause replacementClause;
233                 if (mappings.TryGetValue((DbSetClause)originalClauses[idx], out replacementClause))
234                 {
235                     result.Add(replacementClause);
236                 }
237                 else
238                 {
239                     result.Add(originalClauses[idx]);
240                 }
241             }
242             return result;
243         }
244 
245         internal ModificationOperator Operator { get { return m_operator; } }
246 
247         internal override EntitySet Table { get { return this.m_processor.Table; } }
248 
249         internal override IEnumerable<int> InputIdentifiers
250         {
251             get
252             {
253                 if (null == m_inputIdentifiers)
254                 {
255                     yield break;
256                 }
257                 else
258                 {
259                     foreach (KeyValuePair<int, DbSetClause> inputIdentifier in m_inputIdentifiers)
260                     {
261                         yield return inputIdentifier.Key;
262                     }
263                 }
264             }
265         }
266 
267         internal override IEnumerable<int> OutputIdentifiers
268         {
269             get
270             {
271                 if (null == m_outputIdentifiers)
272                 {
273                     return Enumerable.Empty<int>();
274                 }
275                 return m_outputIdentifiers.Keys;
276             }
277         }
278 
279         internal override UpdateCommandKind Kind
280         {
281             get { return UpdateCommandKind.Dynamic; }
282         }
283 
GetStateEntries(UpdateTranslator translator)284         internal override IList<IEntityStateEntry> GetStateEntries(UpdateTranslator translator)
285         {
286             List<IEntityStateEntry> stateEntries = new List<IEntityStateEntry>(2);
287             if (null != this.OriginalValues)
288             {
289                 foreach (IEntityStateEntry stateEntry in SourceInterpreter.GetAllStateEntries(
290                     this.OriginalValues, translator, this.Table))
291                 {
292                     stateEntries.Add(stateEntry);
293                 }
294             }
295 
296             if (null != this.CurrentValues)
297             {
298                 foreach (IEntityStateEntry stateEntry in SourceInterpreter.GetAllStateEntries(
299                     this.CurrentValues, translator, this.Table))
300                 {
301                     stateEntries.Add(stateEntry);
302                 }
303             }
304             return stateEntries;
305         }
306 
CompareToType(UpdateCommand otherCommand)307         internal override int CompareToType(UpdateCommand otherCommand)
308         {
309             Debug.Assert(!object.ReferenceEquals(this, otherCommand), "caller is supposed to ensure otherCommand is different reference");
310 
311             DynamicUpdateCommand other = (DynamicUpdateCommand)otherCommand;
312 
313             // order by operation type
314             int result = (int)this.Operator - (int)other.Operator;
315             if (0 != result) { return result; }
316 
317             // order by Container.Table
318             result = StringComparer.Ordinal.Compare(this.m_processor.Table.Name, other.m_processor.Table.Name);
319             if (0 != result) { return result; }
320             result = StringComparer.Ordinal.Compare(this.m_processor.Table.EntityContainer.Name, other.m_processor.Table.EntityContainer.Name);
321             if (0 != result) { return result; }
322 
323             // order by table key
324             PropagatorResult thisResult = (this.Operator == ModificationOperator.Delete ? this.OriginalValues : this.CurrentValues);
325             PropagatorResult otherResult = (other.Operator == ModificationOperator.Delete ? other.OriginalValues : other.CurrentValues);
326             for (int i = 0; i < m_processor.KeyOrdinals.Length; i++)
327             {
328                 int keyOrdinal = m_processor.KeyOrdinals[i];
329                 object thisValue = thisResult.GetMemberValue(keyOrdinal).GetSimpleValue();
330                 object otherValue = otherResult.GetMemberValue(keyOrdinal).GetSimpleValue();
331                 result = ByValueComparer.Default.Compare(thisValue, otherValue);
332                 if (0 != result) { return result; }
333             }
334 
335             // If the result is still zero, it means key values are all the same. Switch to synthetic identifiers
336             // to differentiate.
337             for (int i = 0; i < m_processor.KeyOrdinals.Length; i++)
338             {
339                 int keyOrdinal = m_processor.KeyOrdinals[i];
340                 int thisValue = thisResult.GetMemberValue(keyOrdinal).Identifier;
341                 int otherValue = otherResult.GetMemberValue(keyOrdinal).Identifier;
342                 result = thisValue - otherValue;
343                 if (0 != result) { return result; }
344             }
345 
346             return result;
347         }
348     }
349 }
350