//---------------------------------------------------------------------
// <copyright file="MethodCallTranslator.cs" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// @owner Microsoft, Microsoft
//---------------------------------------------------------------------

namespace System.Data.Objects.ELinq
{
    using System.Collections.Generic;
    using System.Data.Common;
    using System.Data.Common.CommandTrees;
    using System.Data.Common.CommandTrees.ExpressionBuilder;
    using System.Data.Entity;
    using System.Data.Metadata.Edm;
    using System.Data.Objects.DataClasses;
    using System.Diagnostics;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;
    using CqtExpression = System.Data.Common.CommandTrees.DbExpression;
    using LinqExpression = System.Linq.Expressions.Expression;

    internal sealed partial class ExpressionConverter
    {
        /// <summary>
        /// Translates System.Linq.Expression.MethodCallExpression to System.Data.Common.CommandTrees.DbExpression
        /// </summary>
        private sealed partial class MethodCallTranslator : TypedTranslator<MethodCallExpression>
        {
            internal MethodCallTranslator()
                : base(ExpressionType.Call) { }

            /// <summary>
            /// Dispatches a LINQ method call to the first matching translator, trying in order:
            /// a registered sequence-method translator, a translator registered for the exact
            /// <see cref="MethodInfo"/> (Visual Basic methods are registered lazily on first use),
            /// an ObjectQuery&lt;T&gt; builder-method translator, a function mapped through
            /// <see cref="EdmFunctionAttribute"/>, and finally ICollection&lt;T&gt;.Contains.
            /// Anything else is handed to the default translator, which throws (via EntityUtil.NotSupported).
            /// </summary>
            /// <param name="parent">The converter driving the translation.</param>
            /// <param name="linq">The method call to translate.</param>
            /// <returns>The command-tree expression equivalent to <paramref name="linq"/>.</returns>
            protected override CqtExpression TypedTranslate(ExpressionConverter parent, MethodCallExpression linq)
            {
                // check if this is a known sequence method
                SequenceMethod sequenceMethod;
                SequenceMethodTranslator sequenceTranslator;
                if (ReflectionUtil.TryIdentifySequenceMethod(linq.Method, out sequenceMethod) &&
                    s_sequenceTranslators.TryGetValue(sequenceMethod, out sequenceTranslator))
                {
                    return sequenceTranslator.Translate(parent, linq, sequenceMethod);
                }
                // check if this is a known method
                CallTranslator callTranslator;
                if (TryGetCallTranslator(linq.Method, out callTranslator))
                {
                    return callTranslator.Translate(parent, linq);
                }

                // check if this is an ObjectQuery<> builder method
                if (ObjectQueryCallTranslator.IsCandidateMethod(linq.Method))
                {
                    ObjectQueryCallTranslator builderTranslator;
                    if (s_objectQueryTranslators.TryGetValue(linq.Method.Name, out builderTranslator))
                    {
                        return builderTranslator.Translate(parent, linq);
                    }
                }

                // check if this method has the FunctionAttribute (known proxy)
                EdmFunctionAttribute functionAttribute = linq.Method.GetCustomAttributes(typeof(EdmFunctionAttribute), false)
                    .Cast<EdmFunctionAttribute>().FirstOrDefault();
                if (null != functionAttribute)
                {
                    return s_functionCallTranslator.TranslateFunctionCall(parent, linq, functionAttribute);
                }

                switch (linq.Method.Name)
                {
                    case "Contains":
                        {
                            // Instance Contains(x) returning bool that implements ICollection<T>.Contains
                            if (linq.Method.GetParameters().Length == 1 && linq.Method.ReturnType.Equals(typeof(bool)))
                            {
                                Type[] genericArguments;
                                if (linq.Method.IsImplementationOfGenericInterfaceMethod(typeof(ICollection<>), out genericArguments))
                                {
                                    return ContainsTranslator.TranslateContains(parent, linq.Object, linq.Arguments[0]);
                                }
                            }
                            break;
                        }
                }

                // fall back on the default translator
                return s_defaultTranslator.Translate(parent, linq);
            }

            #region Static members and initializers
            private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

            // initialize fall-back translator
            private static readonly CallTranslator s_defaultTranslator = new DefaultTranslator();
            private static readonly FunctionCallTranslator s_functionCallTranslator = new FunctionCallTranslator();

            // Copy-on-write: the dictionary referenced here is never modified after it is published.
            // The Visual Basic translators are added by building a new dictionary under
            // s_vbInitializerLock and swapping the reference, so the lock-free lookup in
            // TryGetCallTranslator never reads a Dictionary<,> while another thread is writing to it.
            private static volatile Dictionary<MethodInfo, CallTranslator> s_methodTranslators = InitializeMethodTranslators();

            // NOTE: s_sequenceTranslators must stay declared before s_objectQueryTranslators: static field
            // initializers run in textual order, and the ObjectQueryBuilderCallTranslator instances created by
            // InitializeObjectQueryTranslators read s_sequenceTranslators in their constructor.
            private static readonly Dictionary<SequenceMethod, SequenceMethodTranslator> s_sequenceTranslators = InitializeSequenceMethodTranslators();
            private static readonly Dictionary<string, ObjectQueryCallTranslator> s_objectQueryTranslators = InitializeObjectQueryTranslators();

            // Guarded by s_vbInitializerLock.
            private static bool s_vbMethodsInitialized;
            private static readonly object s_vbInitializerLock = new object();

            /// <summary>
            /// Builds the MethodInfo -> translator map from every translator returned by GetCallTranslators.
            /// </summary>
            /// <returns>A new dictionary; Add throws if two translators claim the same method.</returns>
            private static Dictionary<MethodInfo, CallTranslator> InitializeMethodTranslators()
            {
                // initialize translators for specific methods (e.g., Int32.op_Equality)
                Dictionary<MethodInfo, CallTranslator> methodTranslators = new Dictionary<MethodInfo, CallTranslator>();
                foreach (CallTranslator translator in GetCallTranslators())
                {
                    foreach (MethodInfo method in translator.Methods)
                    {
                        methodTranslators.Add(method, translator);
                    }
                }

                return methodTranslators;
            }

            /// <summary>
            /// Builds the SequenceMethod -> translator map from every translator returned by GetSequenceMethodTranslators.
            /// </summary>
            /// <returns>A new dictionary; Add throws if two translators claim the same sequence method.</returns>
            private static Dictionary<SequenceMethod, SequenceMethodTranslator> InitializeSequenceMethodTranslators()
            {
                // initialize translators for sequence methods (e.g., Sequence.Select)
                Dictionary<SequenceMethod, SequenceMethodTranslator> sequenceTranslators = new Dictionary<SequenceMethod, SequenceMethodTranslator>();
                foreach (SequenceMethodTranslator translator in GetSequenceMethodTranslators())
                {
                    foreach (SequenceMethod method in translator.Methods)
                    {
                        sequenceTranslators.Add(method, translator);
                    }
                }

                return sequenceTranslators;
            }

            /// <summary>
            /// Builds the method-name -> translator map for ObjectQuery&lt;T&gt; builder methods,
            /// keyed with ordinal string comparison.
            /// </summary>
            /// <returns>A new dictionary; a later translator with the same MethodName replaces an earlier one.</returns>
            private static Dictionary<string, ObjectQueryCallTranslator> InitializeObjectQueryTranslators()
            {
                // initialize translators for object query methods (e.g. ObjectQuery<T>.OfType<S>(), ObjectQuery<T>.Include(string) )
                Dictionary<string, ObjectQueryCallTranslator> objectQueryCallTranslators = new Dictionary<string, ObjectQueryCallTranslator>(StringComparer.Ordinal);
                foreach (ObjectQueryCallTranslator translator in GetObjectQueryCallTranslators())
                {
                    objectQueryCallTranslators[translator.MethodName] = translator;
                }

                return objectQueryCallTranslators;
            }

            /// <summary>
            /// Tries to get a translator for the given method info.
            /// If the given method info is declared in the Visual Basic runtime assembly,
            /// it also initializes the Visual Basic translators if they have not been initialized.
            /// </summary>
            /// <param name="methodInfo">The called method.</param>
            /// <param name="callTranslator">The registered translator, or null if none was found.</param>
            /// <returns>True if a translator is registered for <paramref name="methodInfo"/>.</returns>
            private static bool TryGetCallTranslator(MethodInfo methodInfo, out CallTranslator callTranslator)
            {
                // Lock-free fast path: safe because s_methodTranslators is only ever replaced, never mutated.
                if (s_methodTranslators.TryGetValue(methodInfo, out callTranslator))
                {
                    return true;
                }
                // check if this is the visual basic assembly
                if (s_visualBasicAssemblyFullName == methodInfo.DeclaringType.Assembly.FullName)
                {
                    lock (s_vbInitializerLock)
                    {
                        if (!s_vbMethodsInitialized)
                        {
                            InitializeVBMethods(methodInfo.DeclaringType.Assembly);
                            s_vbMethodsInitialized = true;
                        }
                        // try again
                        return s_methodTranslators.TryGetValue(methodInfo, out callTranslator);
                    }
                }

                callTranslator = null;
                return false;
            }

            /// <summary>
            /// Publishes a copy of the method-translator map extended with the Visual Basic translators.
            /// Must be called while holding s_vbInitializerLock.
            /// </summary>
            /// <param name="vbAssembly">The Visual Basic runtime assembly whose methods are registered.</param>
            private static void InitializeVBMethods(Assembly vbAssembly)
            {
                Debug.Assert(!s_vbMethodsInitialized);

                // Build the extended map privately, then swap it in, so concurrent lock-free
                // readers only ever see a fully-populated, unmodified dictionary.
                Dictionary<MethodInfo, CallTranslator> methodTranslators = new Dictionary<MethodInfo, CallTranslator>(s_methodTranslators);
                foreach (CallTranslator translator in GetVisualBasicCallTranslators(vbAssembly))
                {
                    foreach (MethodInfo method in translator.Methods)
                    {
                        methodTranslators.Add(method, translator);
                    }
                }

                s_methodTranslators = methodTranslators;
            }

            private static IEnumerable<CallTranslator> GetVisualBasicCallTranslators(Assembly vbAssembly)
            {
                yield return new VBCanonicalFunctionDefaultTranslator(vbAssembly);
                yield return new VBCanonicalFunctionRenameTranslator(vbAssembly);
                yield return new VBDatePartTranslator(vbAssembly);
            }

            private static IEnumerable<CallTranslator> GetCallTranslators()
            {
                return new CallTranslator[]
                {
                    new CanonicalFunctionDefaultTranslator(),
                    new AsUnicodeFunctionTranslator(),
                    new AsNonUnicodeFunctionTranslator(),
                    new MathPowerTranslator(),
                    new GuidNewGuidTranslator(),
                    new StringContainsTranslator(),
                    new StartsWithTranslator(),
                    new EndsWithTranslator(),
                    new IndexOfTranslator(),
                    new SubstringTranslator(),
                    new RemoveTranslator(),
                    new InsertTranslator(),
                    new IsNullOrEmptyTranslator(),
                    new StringConcatTranslator(),
                    new TrimTranslator(),
                    new TrimStartTranslator(),
                    new TrimEndTranslator(),
                    new SpatialMethodCallTranslator(),
                };
            }

            private static IEnumerable<SequenceMethodTranslator> GetSequenceMethodTranslators()
            {
                yield return new ConcatTranslator();
                yield return new UnionTranslator();
                yield return new IntersectTranslator();
                yield return new ExceptTranslator();
                yield return new DistinctTranslator();
                yield return new WhereTranslator();
                yield return new SelectTranslator();
                yield return new OrderByTranslator();
                yield return new OrderByDescendingTranslator();
                yield return new ThenByTranslator();
                yield return new ThenByDescendingTranslator();
                yield return new SelectManyTranslator();
                yield return new AnyTranslator();
                yield return new AnyPredicateTranslator();
                yield return new AllTranslator();
                yield return new JoinTranslator();
                yield return new GroupByTranslator();
                yield return new MaxTranslator();
                yield return new MinTranslator();
                yield return new AverageTranslator();
                yield return new SumTranslator();
                yield return new CountTranslator();
                yield return new LongCountTranslator();
                yield return new CastMethodTranslator();
                yield return new GroupJoinTranslator();
                yield return new OfTypeTranslator();
                yield return new PassthroughTranslator();
                yield return new DefaultIfEmptyTranslator();
                yield return new FirstTranslator();
                yield return new FirstPredicateTranslator();
                yield return new FirstOrDefaultTranslator();
                yield return new FirstOrDefaultPredicateTranslator();
                yield return new TakeTranslator();
                yield return new SkipTranslator();
                yield return new SingleTranslator();
                yield return new SinglePredicateTranslator();
                yield return new SingleOrDefaultTranslator();
                yield return new SingleOrDefaultPredicateTranslator();
                yield return new ContainsTranslator();
            }

            private static IEnumerable<ObjectQueryCallTranslator> GetObjectQueryCallTranslators()
            {
                yield return new ObjectQueryBuilderDistinctTranslator();
                yield return new ObjectQueryBuilderExceptTranslator();
                yield return new ObjectQueryBuilderFirstTranslator();
                yield return new ObjectQueryIncludeTranslator();
                yield return new ObjectQueryBuilderIntersectTranslator();
                yield return new ObjectQueryBuilderOfTypeTranslator();
                yield return new ObjectQueryBuilderUnionTranslator();
                yield return new ObjectQueryMergeAsTranslator();
                yield return new ObjectQueryIncludeSpanTranslator();
            }

            /// <summary>
            /// Determines whether a two-parameter selector simply packs both of its parameters, in order,
            /// into a new object, i.e. (a, b) => new { X = a, Y = b }.
            /// </summary>
            /// <param name="selectorLambda">The selector to inspect.</param>
            /// <param name="converter">Converter used to register and validate the projection initializer.</param>
            /// <param name="leftName">On success, the member name receiving the first parameter.</param>
            /// <param name="rightName">On success, the member name receiving the second parameter.</param>
            /// <param name="initializerMetadata">On success, the initializer metadata for the renaming projection.</param>
            /// <returns>True if the selector is a trivial rename; otherwise false (all out values are null).</returns>
            private static bool IsTrivialRename(
                LambdaExpression selectorLambda,
                ExpressionConverter converter,
                out string leftName,
                out string rightName,
                out InitializerMetadata initializerMetadata)
            {
                leftName = null;
                rightName = null;
                initializerMetadata = null;

                if (selectorLambda.Parameters.Count != 2 ||
                    selectorLambda.Body.NodeType != ExpressionType.New)
                {
                    return false;
                }

                var newExpression = (NewExpression)selectorLambda.Body;

                if (newExpression.Arguments.Count != 2)
                {
                    return false;
                }

                if (newExpression.Arguments[0] != selectorLambda.Parameters[0] ||
                    newExpression.Arguments[1] != selectorLambda.Parameters[1])
                {
                    return false;
                }

                // NewExpression.Members is null when the constructor call carries no member mapping,
                // which is typically the case for ordinary (non-anonymous-type) constructor calls.
                // Such a selector has no member names to rename to; indexing Members would throw.
                if (newExpression.Members == null)
                {
                    return false;
                }

                leftName = newExpression.Members[0].Name;
                rightName = newExpression.Members[1].Name;

                // Construct a new initializer type in metadata for the renaming projection (provides the
                // necessary context for the object materializer)
                initializerMetadata = InitializerMetadata.CreateProjectionInitializer(converter.EdmItemCollection, newExpression);
                converter.ValidateInitializerMetadata(initializerMetadata);

                return true;
            }
            #endregion

            #region Method translators
            /// <summary>
            /// Base class for translators bound to a fixed set of CLR methods.
            /// </summary>
            private abstract class CallTranslator
            {
                private readonly IEnumerable<MethodInfo> _methods;
                protected CallTranslator(params MethodInfo[] methods) { _methods = methods; }
                protected CallTranslator(IEnumerable<MethodInfo> methods) { _methods = methods; }
                internal IEnumerable<MethodInfo> Methods { get { return _methods; } }
                internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);
                public override string ToString()
                {
                    return GetType().Name;
                }
            }

            /// <summary>
            /// Base class for translators of ObjectQuery&lt;T&gt; instance methods, looked up by method name.
            /// </summary>
            private abstract class ObjectQueryCallTranslator : CallTranslator
            {
                /// <summary>
                /// True for public methods (or the internal MergeAs / IncludeSpan methods) declared on a
                /// constructed ObjectQuery&lt;T&gt; type.
                /// </summary>
                internal static bool IsCandidateMethod(MethodInfo method)
                {
                    Type declaringType = method.DeclaringType;
                    return ((method.IsPublic || (method.IsAssembly && (method.Name == "MergeAs" || method.Name == "IncludeSpan"))) &&
                        null != declaringType &&
                        declaringType.IsGenericType &&
                        typeof(ObjectQuery<>) == declaringType.GetGenericTypeDefinition());
                }

                internal static LinqExpression RemoveConvertToObjectQuery(LinqExpression queryExpression)
                {
                    // Remove the Convert(ObjectQuery<T>) that was placed around the LINQ expression that defines an ObjectQuery to allow it to be used as the argument in a call to MergeAs or IncludeSpan
                    if (queryExpression.NodeType == ExpressionType.Convert)
                    {
                        UnaryExpression convertExpression = (UnaryExpression)queryExpression;
                        Type argumentType = convertExpression.Operand.Type;
                        if (argumentType.IsGenericType && (typeof(IQueryable<>) == argumentType.GetGenericTypeDefinition() || typeof(IOrderedQueryable<>) == argumentType.GetGenericTypeDefinition()))
                        {
                            Debug.Assert(convertExpression.Type.IsGenericType && typeof(ObjectQuery<>) == convertExpression.Type.GetGenericTypeDefinition(), "MethodCall with internal MergeAs/IncludeSpan method was not constructed by LINQ to Entities?");
                            queryExpression = convertExpression.Operand;
                        }
                    }

                    return queryExpression;
                }

                private readonly string _methodName;

                protected ObjectQueryCallTranslator(string methodName)
                {
                    _methodName = methodName;
                }

                internal string MethodName { get { return _methodName; } }
            }

            /// <summary>
            /// Translates an ObjectQuery&lt;T&gt; builder method by delegating to the translator registered
            /// for its equivalent sequence method.
            /// </summary>
            private abstract class ObjectQueryBuilderCallTranslator : ObjectQueryCallTranslator
            {
                private readonly SequenceMethodTranslator _translator;

                protected ObjectQueryBuilderCallTranslator(string methodName, SequenceMethod sequenceEquivalent)
                    : base(methodName)
                {
                    bool translatorFound = s_sequenceTranslators.TryGetValue(sequenceEquivalent, out _translator);
                    Debug.Assert(translatorFound, "Translator not found for " + sequenceEquivalent.ToString());
                }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    return _translator.Translate(parent, call);
                }
            }

            private sealed class ObjectQueryBuilderUnionTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderUnionTranslator()
                    : base("Union", SequenceMethod.Union)
                {
                }
            }

            private sealed class ObjectQueryBuilderIntersectTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderIntersectTranslator()
                    : base("Intersect", SequenceMethod.Intersect)
                {
                }
            }

            private sealed class ObjectQueryBuilderExceptTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderExceptTranslator()
                    : base("Except", SequenceMethod.Except)
                {
                }
            }

            private sealed class ObjectQueryBuilderDistinctTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderDistinctTranslator()
                    : base("Distinct", SequenceMethod.Distinct)
                {
                }
            }

            private sealed class ObjectQueryBuilderOfTypeTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderOfTypeTranslator()
                    : base("OfType", SequenceMethod.OfType)
                {
                }
            }

            private sealed class ObjectQueryBuilderFirstTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderFirstTranslator()
                    : base("First", SequenceMethod.First)
                {
                }
            }

            /// <summary>
            /// Translates ObjectQuery&lt;T&gt;.Include(string) by adding the constant include path to the
            /// span mapped to the translated query.
            /// </summary>
            private sealed class ObjectQueryIncludeTranslator : ObjectQueryCallTranslator
            {
                internal ObjectQueryIncludeTranslator()
                    : base("Include")
                {
                }

                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(string)), "Invalid Include arguments?");
                    CqtExpression queryExpression = parent.TranslateExpression(call.Object);
                    Span span;
                    if (!parent.TryGetSpan(queryExpression, out span))
                    {
                        span = null;
                    }
                    CqtExpression arg = parent.TranslateExpression(call.Arguments[0]);
                    string includePath = null;
                    if (arg.ExpressionKind == DbExpressionKind.Constant)
                    {
                        includePath = (string)((DbConstantExpression)arg).Value;
                    }
                    else
                    {
                        // The 'Include' method implementation on ELinqQueryState creates
                        // a method call expression with a string constant argument taking
                        // the value of the string argument passed to ObjectQuery.Include,
                        // and so this is the only supported pattern here.
                        throw EntityUtil.NotSupported(Strings.ELinq_UnsupportedInclude);
                    }
                    if (parent.CanIncludeSpanInfo())
                    {
                        span = Span.IncludeIn(span, includePath);
                    }
                    return parent.AddSpanMapping(queryExpression, span);
                }
            }

            /// <summary>
            /// Translates the internal ObjectQuery&lt;T&gt;.MergeAs(MergeOption) marker call: records the
            /// constant merge option with the converter, then translates the wrapped query.
            /// </summary>
            private sealed class ObjectQueryMergeAsTranslator : ObjectQueryCallTranslator
            {
                internal ObjectQueryMergeAsTranslator()
                    : base("MergeAs")
                {
                }

                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(MergeOption)), "Invalid MergeAs arguments?");

                    // Note that the MergeOption must be inspected and applied BEFORE visiting the argument,
                    // so that it is 'locked down' before a sub-query with a user-specified merge option is encountered.
                    if (call.Arguments[0].NodeType != ExpressionType.Constant)
                    {
                        // The 'MergeAs' method implementation on ObjectQuery<T> creates
                        // a method call expression with a MergeOption constant argument taking
                        // the value of the merge option argument passed to ObjectQuery.MergeAs,
                        // and so this is the only supported pattern here.
                        throw EntityUtil.NotSupported(Strings.ELinq_UnsupportedMergeAs);
                    }

                    MergeOption mergeAsOption = (MergeOption)((ConstantExpression)call.Arguments[0]).Value;
                    EntityUtil.CheckArgumentMergeOption(mergeAsOption);
                    parent.NotifyMergeOption(mergeAsOption);

                    LinqExpression inputQuery = RemoveConvertToObjectQuery(call.Object);
                    CqtExpression queryExpression = parent.TranslateExpression(inputQuery);
                    Span span;
                    if (!parent.TryGetSpan(queryExpression, out span))
                    {
                        span = null;
                    }

                    return parent.AddSpanMapping(queryExpression, span);
                }
            }

            /// <summary>
            /// Translates the internal ObjectQuery&lt;T&gt;.IncludeSpan(Span) marker call by mapping the
            /// constant span (or null, if span info cannot be included) to the translated query.
            /// </summary>
            private sealed class ObjectQueryIncludeSpanTranslator : ObjectQueryCallTranslator
            {
                internal ObjectQueryIncludeSpanTranslator()
                    : base("IncludeSpan")
                {
                }

                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(Span)), "Invalid IncludeSpan arguments?");
                    Debug.Assert(call.Arguments[0].NodeType == ExpressionType.Constant, "Whenever an IncludeSpan MethodCall is inlined, the argument must be a constant");
                    Span span = (Span)((ConstantExpression)call.Arguments[0]).Value;
                    LinqExpression inputQuery = RemoveConvertToObjectQuery(call.Object);
                    DbExpression queryExpression = parent.TranslateExpression(inputQuery);
                    if (!(parent.CanIncludeSpanInfo()))
                    {
                        span = null;
                    }
                    return parent.AddSpanMapping(queryExpression, span);
                }
            }

            /// <summary>
            /// Fallback translator: always throws, suggesting a supported alternative method when one is known.
            /// </summary>
            private sealed class DefaultTranslator : CallTranslator
            {
                internal DefaultTranslator() : base() { }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    MethodInfo suggestedMethodInfo;
                    if (TryGetAlternativeMethod(call.Method, out suggestedMethodInfo))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethodSuggestedAlternative(call.Method, suggestedMethodInfo));
                    }
                    //The default error message
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethod(call.Method));
                }

                #region Static Members
                // Copy-on-write: the dictionary referenced here is never modified after it is published;
                // InitializeVBMethods builds an extended copy under s_vbInitializerLock and swaps it in,
                // so the lock-free lookup in TryGetAlternativeMethod never races with a write.
                private static volatile Dictionary<MethodInfo, MethodInfo> s_alternativeMethods = InitializeAlternateMethodInfos();

                // Guarded by s_vbInitializerLock.
                private static bool s_vbMethodsInitialized;
                private static readonly object s_vbInitializerLock = new object();

                /// <summary>
                /// Tries to check whether there is an alternative method suggested instead of the given unsupported one.
                /// </summary>
                /// <param name="originalMethodInfo">The unsupported method.</param>
                /// <param name="suggestedMethodInfo">The suggested alternative, or null if there is none.</param>
                /// <returns>True if an alternative is known for <paramref name="originalMethodInfo"/>.</returns>
                private static bool TryGetAlternativeMethod(MethodInfo originalMethodInfo, out MethodInfo suggestedMethodInfo)
                {
                    if (s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo))
                    {
                        return true;
                    }
                    // check if this is the visual basic assembly
                    if (s_visualBasicAssemblyFullName == originalMethodInfo.DeclaringType.Assembly.FullName)
                    {
                        lock (s_vbInitializerLock)
                        {
                            if (!s_vbMethodsInitialized)
                            {
                                InitializeVBMethods(originalMethodInfo.DeclaringType.Assembly);
                                s_vbMethodsInitialized = true;
                            }
                            // try again
                            return s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo);
                        }
                    }
                    suggestedMethodInfo = null;
                    return false;
                }

                /// <summary>
                /// Initializes the dictionary of alternative methods.
                /// Currently, it simply initializes an empty dictionary.
                /// </summary>
                /// <returns>An empty dictionary.</returns>
                private static Dictionary<MethodInfo, MethodInfo> InitializeAlternateMethodInfos()
                {
                    return new Dictionary<MethodInfo, MethodInfo>(1);
                }

                /// <summary>
                /// Publishes a copy of the alternative-method map extended with the VB methods.
                /// Must be called while holding s_vbInitializerLock. If the expected VB type or
                /// overloads cannot be found, no alternatives are added.
                /// </summary>
                /// <param name="vbAssembly">The Visual Basic runtime assembly.</param>
                private static void InitializeVBMethods(Assembly vbAssembly)
                {
                    Debug.Assert(!s_vbMethodsInitialized);

                    //Handle { Mid(arg1, ar2), Mid(arg1, arg2, arg3) }
                    Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
                    if (stringsType == null)
                    {
                        return;
                    }

                    MethodInfo midWithoutLength = stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null);
                    MethodInfo midWithLength = stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int), typeof(int) }, null);
                    if (midWithoutLength == null || midWithLength == null)
                    {
                        return;
                    }

                    Dictionary<MethodInfo, MethodInfo> alternativeMethods = new Dictionary<MethodInfo, MethodInfo>(s_alternativeMethods);
                    alternativeMethods.Add(midWithoutLength, midWithLength);
                    s_alternativeMethods = alternativeMethods;
                }
                #endregion
            }

            /// <summary>
            /// Translates calls to CLR methods marked with <see cref="EdmFunctionAttribute"/> into
            /// invocations of the named EDM function.
            /// </summary>
            private sealed class FunctionCallTranslator
            {
                internal FunctionCallTranslator() { }

                internal DbExpression TranslateFunctionCall(ExpressionConverter parent, MethodCallExpression call, EdmFunctionAttribute functionAttribute)
                {
                    //Validate that the attribute parameters are not null or empty
                    FunctionCallTranslator.ValidateFunctionAttributeParameter(call, functionAttribute.NamespaceName, "namespaceName");
                    FunctionCallTranslator.ValidateFunctionAttributeParameter(call, functionAttribute.FunctionName, "functionName");

                    // Translate the inputs
                    var arguments = call.Arguments.Select(a => UnwrapNoOpConverts(a)).Select(b => NormalizeAllSetSources(parent, parent.TranslateExpression(b))).ToList();
                    var argumentTypes = arguments.Select(a => a.ResultType).ToList();

                    //Resolve the function
                    EdmFunction function = parent.FindFunction(functionAttribute.NamespaceName, functionAttribute.FunctionName, argumentTypes, false, call);

                    if (!function.IsComposableAttribute)
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.CannotCallNoncomposableFunction(function.FullName));
                    }

                    DbExpression result = function.Invoke(arguments);

                    return ValidateReturnType(result, result.ResultType, parent, call, call.Type, false);
                }

                /// <summary>
                /// Recursively rewrite the argument expression to unwrap any "structured" set sources
                /// using ExpressionConverter.NormalizeSetSource(). This is currently required for IGrouping
                /// and EntityCollection as argument types to functions.
                /// NOTE: Changes made to this function might have to be applied to ExpressionConverter.NormalizeSetSource() too.
                /// </summary>
                /// <param name="parent">Converter supplying the alias generator and NormalizeSetSource.</param>
                /// <param name="argumentExpr">The already-translated argument.</param>
                /// <returns>The rewritten argument, always passed through parent.NormalizeSetSource.</returns>
                private DbExpression NormalizeAllSetSources(ExpressionConverter parent, DbExpression argumentExpr)
                {
                    DbExpression newExpr = null;
                    BuiltInTypeKind type = argumentExpr.ResultType.EdmType.BuiltInTypeKind;

                    switch (type)
                    {
                        case BuiltInTypeKind.CollectionType:
                            {
                                // Normalize each element via a projection over the collection.
                                DbExpressionBinding bindingExpr = DbExpressionBuilder.BindAs(argumentExpr, parent.AliasGenerator.Next());
                                DbExpression normalizedExpr = NormalizeAllSetSources(parent, bindingExpr.Variable);
                                if (normalizedExpr != bindingExpr.Variable)
                                {
                                    newExpr = DbExpressionBuilder.Project(bindingExpr, normalizedExpr);
                                }
                                break;
                            }
                        case BuiltInTypeKind.RowType:
                            {
                                // Rebuild the row only if at least one column changed.
                                List<KeyValuePair<string, DbExpression>> newColumns = new List<KeyValuePair<string, DbExpression>>();
                                RowType rowType = argumentExpr.ResultType.EdmType as RowType;
                                bool isAnyPropertyChanged = false;

                                foreach (EdmProperty recColumn in rowType.Properties)
                                {
                                    DbPropertyExpression propertyExpr = argumentExpr.Property(recColumn);
                                    newExpr = NormalizeAllSetSources(parent, propertyExpr);
                                    if (newExpr != propertyExpr)
                                    {
                                        isAnyPropertyChanged = true;
                                        newColumns.Add(new KeyValuePair<string, DbExpression>(propertyExpr.Property.Name, newExpr));
                                    }
                                    else
                                    {
                                        newColumns.Add(new KeyValuePair<string, DbExpression>(propertyExpr.Property.Name, propertyExpr));
                                    }
                                }

                                if (isAnyPropertyChanged)
                                {
                                    newExpr = DbExpressionBuilder.NewRow(newColumns);
                                }
                                else
                                {
                                    newExpr = argumentExpr;
                                }
                                break;
                            }
                    }

                    // If the expression has not changed, return the original expression
                    if (newExpr != null && newExpr != argumentExpr)
                    {
                        return parent.NormalizeSetSource(newExpr);
                    }
                    else
                    {
                        return parent.NormalizeSetSource(argumentExpr);
                    }
                }


                /// <summary>
                /// Removes casts where possible, for example Cast from a Reference type to Object type
                /// Handles nested converts recursively. Removing no-op casts is required to prevent the
                /// expression converter from complaining.
                /// </summary>
                /// <param name="expression">The function argument expression.</param>
                /// <returns>The innermost operand reachable through Convert nodes whose target type is
                /// assignable from that operand's type; otherwise <paramref name="expression"/> itself.</returns>
                private Expression UnwrapNoOpConverts(Expression expression)
                {
                    if (expression.NodeType == ExpressionType.Convert)
                    {
                        UnaryExpression convertExpression = (UnaryExpression)expression;

                        // Unwrap the operand before checking assignability for a "postfix" rewrite.
                        // The modified conversion tree is constructed bottom-up.
                        Expression operand = UnwrapNoOpConverts(convertExpression.Operand);
                        if (expression.Type.IsAssignableFrom(operand.Type))
                        {
                            return operand;
                        }
                    }
                    return expression;
                }

                /// <summary>
                /// Checks if the return type specified by the call expression matches that expected by the
                /// function definition. Performs a recursive check in case of Collection type.
                /// </summary>
                /// <param name="result">DbFunctionExpression for the function definition</param>
                /// <param name="actualReturnType">Return type expected by the function definition</param>
                /// <param name="parent">Converter used to map CLR types to model types and to align types.</param>
                /// <param name="call">LINQ MethodCallExpression</param>
                /// <param name="clrReturnType">Return type specified by the call</param>
                /// <param name="isElementOfCollection">Indicates if current call is for an Element of a Collection type</param>
                /// <returns>DbFunctionExpression with aligned return types</returns>
                private DbExpression ValidateReturnType(DbExpression result, TypeUsage actualReturnType, ExpressionConverter parent, MethodCallExpression call, Type clrReturnType, bool isElementOfCollection)
                {
                    BuiltInTypeKind modelType = actualReturnType.EdmType.BuiltInTypeKind;
                    switch (modelType)
                    {
                        case BuiltInTypeKind.CollectionType:
                            {
                                //Verify if this is a collection type (if so, recursively resolve)
                                if (!clrReturnType.IsGenericType)
                                {
                                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributedFunctionWithWrongReturnType(call.Method, call.Method.DeclaringType));
                                }
                                Type genericType = clrReturnType.GetGenericTypeDefinition();
                                if ((genericType != typeof(IEnumerable<>)) && (genericType != typeof(IQueryable<>)))
                                {
                                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributedFunctionWithWrongReturnType(call.Method, call.Method.DeclaringType));
                                }
                                Type elementType = clrReturnType.GetGenericArguments()[0];
                                result = ValidateReturnType(result, TypeHelpers.GetElementTypeUsage(actualReturnType), parent, call, elementType, true);
                                break;
                            }
                        case BuiltInTypeKind.RowType:
                            {
                                if (clrReturnType != typeof(DbDataRecord))
                                {
                                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributedFunctionWithWrongReturnType(call.Method, call.Method.DeclaringType));
                                }
                                break;
                            }
                        case BuiltInTypeKind.RefType:
                            {
                                if (clrReturnType != typeof(EntityKey))
                                {
                                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributedFunctionWithWrongReturnType(call.Method, call.Method.DeclaringType));
                                }
                                break;
                            }
                        //Handles Primitive types, Entity types and Complex types
                        default:
                            {
                                // For collection type, look for exact match of element types.
                                if (isElementOfCollection)
                                {
                                    TypeUsage toType = parent.GetCastTargetType(actualReturnType, clrReturnType, null, false);
                                    if (toType != null)
                                    {
                                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributedFunctionWithWrongReturnType(call.Method, call.Method.DeclaringType));
                                    }
                                }

                                // Check whether the return type specified by the call can be aligned
                                // with the actual return type of the function
                                TypeUsage expectedReturnType = parent.GetValueLayerType(clrReturnType);
                                if (!TypeSemantics.IsPromotableTo(actualReturnType, expectedReturnType))
                                {
                                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributedFunctionWithWrongReturnType(call.Method, call.Method.DeclaringType));
                                }

                                // For scalar return types, align the return types if needed.
                                if (!isElementOfCollection)
                                {
                                    result = parent.AlignTypes(result, clrReturnType);
                                }
                                break;
                            }
                    }
                    return result;
                }

                /// <summary>
                /// Validates that the given parameterValue is not null or empty.
                /// </summary>
                /// <param name="call">The call whose method carries the attribute (used in the error message).</param>
                /// <param name="parameterValue">The attribute value to check.</param>
                /// <param name="parameterName">The attribute parameter name reported in the error message.</param>
                internal static void ValidateFunctionAttributeParameter(MethodCallExpression call, string parameterValue, string parameterName)
                {
                    if (String.IsNullOrEmpty(parameterValue))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_EdmFunctionAttributeParameterNameNotValid(call.Method, call.Method.DeclaringType, parameterName));
                    }
                }
            }

            /// <summary>
            /// Maps selected Math, Decimal and String methods directly onto canonical functions of the same name.
            /// </summary>
            private sealed class CanonicalFunctionDefaultTranslator : CallTranslator
            {
                internal CanonicalFunctionDefaultTranslator()
                    : base(GetMethods()) { }

                private static IEnumerable<MethodInfo> GetMethods()
                {
                    var result = new List<MethodInfo>
                    {
                        //Math functions
                        typeof(Math).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null),
                        typeof(Math).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null),
                        typeof(Math).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null),
                        typeof(Math).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null),
                        typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null),
                        typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null),
                        typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal), typeof(int) }, null),
                        typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double), typeof(int) }, null),

                        //Decimal functions
                        typeof(Decimal).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null),
                        typeof(Decimal).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null),
                        typeof(Decimal).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null),
                        typeof(Decimal).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal), typeof(int) }, null),

                        //String functions
                        typeof(String).GetMethod("Replace", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(String), typeof(String) }, null),
                        typeof(String).GetMethod("ToLower", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null),
                        typeof(String).GetMethod("ToUpper", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null),
                        typeof(String).GetMethod("Trim", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null),
                    };

                    // Math.Abs
                    foreach (Type argType in new [] { typeof(decimal), typeof(double), typeof(float), typeof(int), typeof(long), typeof(sbyte), typeof(short) })
                    {
                        result.Add(typeof(Math).GetMethod("Abs", BindingFlags.Public | BindingFlags.Static, null, new Type[] { argType }, null));
                    }

                    return result;
                }

                // Default translator for method calls into canonical functions.
                // Translation:
                //     MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn)
                //     this.MethodName(arg1, arg2, .., argn) -> MethodName(this, arg1, arg2, .., argn)
                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    LinqExpression[] linqArguments;

                    if (!call.Method.IsStatic)
                    {
                        Debug.Assert(call.Object != null, "Instance method without this");
                        List<LinqExpression> arguments = new List<LinqExpression>(call.Arguments.Count + 1);
                        arguments.Add(call.Object);
                        arguments.AddRange(call.Arguments);
                        linqArguments = arguments.ToArray();
                    }
                    else
                    {
                        linqArguments = call.Arguments.ToArray();
                    }
                    return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, linqArguments);
                }
            }

            private abstract class AsUnicodeNonUnicodeBaseFunctionTranslator : CallTranslator
            {
                private bool _isUnicode;

                protected AsUnicodeNonUnicodeBaseFunctionTranslator(IEnumerable<MethodInfo> methods, bool isUnicode)
                    : base(methods)
                {
                    _isUnicode = isUnicode;
                }

                // Translation:
                //     object.AsUnicode() -> object (In its TypeUsage, the unicode facet value is set to true explicitly)
                //     object.AsNonUnicode() -> object (In its TypeUsage, the unicode facet is set to false)
                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    DbExpression argument = parent.TranslateExpression(call.Arguments[0]);
                    DbExpression recreatedArgument;
                    TypeUsage updatedType = argument.ResultType.ShallowCopy(new FacetValues { Unicode = _isUnicode });

                    switch (argument.ExpressionKind)
                    {
                        case DbExpressionKind.Constant:
                            recreatedArgument =
DbExpressionBuilder.Constant(updatedType, ((DbConstantExpression)argument).Value); 899 break; 900 case DbExpressionKind.ParameterReference: 901 recreatedArgument = DbExpressionBuilder.Parameter(updatedType, ((DbParameterReferenceExpression)argument).ParameterName); 902 break; 903 case DbExpressionKind.Null: 904 recreatedArgument = DbExpressionBuilder.Null(updatedType); 905 break; 906 default: 907 throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedAsUnicodeAndAsNonUnicode(call.Method)); 908 } 909 return recreatedArgument; 910 } 911 } 912 private sealed class AsUnicodeFunctionTranslator : AsUnicodeNonUnicodeBaseFunctionTranslator 913 { AsUnicodeFunctionTranslator()914 internal AsUnicodeFunctionTranslator() 915 : base(GetMethods(), true) { } 916 GetMethods()917 private static IEnumerable<MethodInfo> GetMethods() 918 { 919 yield return typeof(EntityFunctions).GetMethod(ExpressionConverter.AsUnicode, BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); 920 } 921 } 922 923 private sealed class AsNonUnicodeFunctionTranslator : AsUnicodeNonUnicodeBaseFunctionTranslator 924 { AsNonUnicodeFunctionTranslator()925 internal AsNonUnicodeFunctionTranslator() 926 : base(GetMethods(), false) { } 927 GetMethods()928 private static IEnumerable<MethodInfo> GetMethods() 929 { 930 yield return typeof(EntityFunctions).GetMethod(ExpressionConverter.AsNonUnicode, BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); 931 } 932 } 933 934 #region System.Math method translators 935 private sealed class MathPowerTranslator : CallTranslator 936 { MathPowerTranslator()937 internal MathPowerTranslator() 938 : base(new[] 939 { 940 typeof(Math).GetMethod("Pow", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double), typeof(double) }, null), 941 }) 942 { 943 } 944 Translate(ExpressionConverter parent, MethodCallExpression call)945 internal override DbExpression Translate(ExpressionConverter 
parent, MethodCallExpression call) 946 { 947 DbExpression arg1 = parent.TranslateExpression(call.Arguments[0]); 948 DbExpression arg2 = parent.TranslateExpression(call.Arguments[1]); 949 return arg1.Power(arg2); 950 } 951 } 952 #endregion 953 954 #region System.Guid method translators 955 private sealed class GuidNewGuidTranslator : CallTranslator 956 { GuidNewGuidTranslator()957 internal GuidNewGuidTranslator() 958 : base(new[] 959 { 960 typeof(Guid).GetMethod("NewGuid", BindingFlags.Public | BindingFlags.Static, null, Type.EmptyTypes, null), 961 }) 962 { 963 } 964 Translate(ExpressionConverter parent, MethodCallExpression call)965 internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call) 966 { 967 return EdmFunctions.NewGuid(); 968 } 969 } 970 #endregion 971 972 #region System.String Method Translators 973 private sealed class StringContainsTranslator : CallTranslator 974 { StringContainsTranslator()975 internal StringContainsTranslator() 976 : base(GetMethods()) { } 977 GetMethods()978 private static IEnumerable<MethodInfo> GetMethods() 979 { 980 yield return typeof(String).GetMethod("Contains", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); 981 } 982 983 // Translation: 984 // object.EndsWith(argument) -> 985 // 1) if argument is a constant or parameter and the provider supports escaping: 986 // object like "%" + argument1 + "%", where argument1 is argument escaped by the provider 987 // 2) Otherwise: 988 // object.Contains(argument) -> IndexOf(argument, object) > 0 Translate(ExpressionConverter parent, MethodCallExpression call)989 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 990 { 991 return parent.TranslateFunctionIntoLike(call, true, true, CreateDefaultTranslation); 992 } 993 994 // DefaultTranslation: 995 // object.Contains(argument) -> IndexOf(argument, object) > 0 CreateDefaultTranslation(ExpressionConverter parent, 
MethodCallExpression call, DbExpression patternExpression, DbExpression inputExpression)996 private static DbExpression CreateDefaultTranslation(ExpressionConverter parent, MethodCallExpression call, DbExpression patternExpression, DbExpression inputExpression) 997 { 998 DbFunctionExpression indexOfExpression = parent.CreateCanonicalFunction(ExpressionConverter.IndexOf, call, patternExpression, inputExpression); 999 DbComparisonExpression comparisonExpression = indexOfExpression.GreaterThan(DbExpressionBuilder.Constant(0)); 1000 return comparisonExpression; 1001 } 1002 } 1003 private sealed class IndexOfTranslator : CallTranslator 1004 { IndexOfTranslator()1005 internal IndexOfTranslator() 1006 : base(GetMethods()) { } 1007 GetMethods()1008 private static IEnumerable<MethodInfo> GetMethods() 1009 { 1010 yield return typeof(String).GetMethod("IndexOf", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); 1011 } 1012 1013 // Translation: 1014 // IndexOf(arg1) -> IndexOf(arg1, this) - 1 Translate(ExpressionConverter parent, MethodCallExpression call)1015 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1016 { 1017 Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IndexOf"); 1018 1019 DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); 1020 CqtExpression minusExpression = indexOfExpression.Minus(DbExpressionBuilder.Constant(1)); 1021 1022 return minusExpression; 1023 } 1024 } 1025 private sealed class StartsWithTranslator : CallTranslator 1026 { StartsWithTranslator()1027 internal StartsWithTranslator() 1028 : base(GetMethods()) { } 1029 GetMethods()1030 private static IEnumerable<MethodInfo> GetMethods() 1031 { 1032 yield return typeof(String).GetMethod("StartsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); 1033 } 1034 
1035 // Translation: 1036 // object.StartsWith(argument) -> 1037 // 1) if argument is a constant or parameter and the provider supports escaping: 1038 // object like argument1 + "%", where argument1 is argument escaped by the provider 1039 // 2) otherwise: 1040 // IndexOf(argument, object) == 1 Translate(ExpressionConverter parent, MethodCallExpression call)1041 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1042 { 1043 return parent.TranslateFunctionIntoLike(call, false, true, CreateDefaultTranslation); 1044 } 1045 1046 // Default translation: 1047 // object.StartsWith(argument) -> IndexOf(argument, object) == 1 CreateDefaultTranslation(ExpressionConverter parent, MethodCallExpression call, DbExpression patternExpression, DbExpression inputExpression)1048 private static DbExpression CreateDefaultTranslation(ExpressionConverter parent, MethodCallExpression call, DbExpression patternExpression, DbExpression inputExpression) 1049 { 1050 DbExpression indexOfExpression = parent.CreateCanonicalFunction(ExpressionConverter.IndexOf, call, patternExpression, inputExpression) 1051 .Equal(DbExpressionBuilder.Constant(1)); 1052 return indexOfExpression; 1053 } 1054 } 1055 1056 private sealed class EndsWithTranslator : CallTranslator 1057 { EndsWithTranslator()1058 internal EndsWithTranslator() 1059 : base(GetMethods()) { } 1060 GetMethods()1061 private static IEnumerable<MethodInfo> GetMethods() 1062 { 1063 yield return typeof(String).GetMethod("EndsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); 1064 } 1065 1066 // Translation: 1067 // object.EndsWith(argument) -> 1068 // 1) if argument is a constant or parameter and the provider supports escaping: 1069 // object like "%" + argument1, where argument1 is argument escaped by the provider 1070 // 2) Otherwise: 1071 // object.EndsWith(argument) -> IndexOf(Reverse(argument), Reverse(object)) = 1 Translate(ExpressionConverter parent, 
MethodCallExpression call)1072 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1073 { 1074 return parent.TranslateFunctionIntoLike(call, true, false, CreateDefaultTranslation); 1075 } 1076 1077 // Default Translation: 1078 // object.EndsWith(argument) -> IndexOf(Reverse(argument), Reverse(object)) = 1 CreateDefaultTranslation(ExpressionConverter parent, MethodCallExpression call, DbExpression patternExpression, DbExpression inputExpression)1079 private static DbExpression CreateDefaultTranslation(ExpressionConverter parent, MethodCallExpression call, DbExpression patternExpression, DbExpression inputExpression) 1080 { 1081 DbFunctionExpression reversePatternExpression = parent.CreateCanonicalFunction(ExpressionConverter.Reverse, call, patternExpression); 1082 DbFunctionExpression reverseInputExpression = parent.CreateCanonicalFunction(ExpressionConverter.Reverse, call, inputExpression); 1083 1084 DbExpression indexOfExpression = parent.CreateCanonicalFunction(ExpressionConverter.IndexOf, call, reversePatternExpression, reverseInputExpression) 1085 .Equal(DbExpressionBuilder.Constant(1)); 1086 return indexOfExpression; 1087 } 1088 } 1089 private sealed class SubstringTranslator : CallTranslator 1090 { SubstringTranslator()1091 internal SubstringTranslator() 1092 : base(GetMethods()) { } 1093 GetMethods()1094 private static IEnumerable<MethodInfo> GetMethods() 1095 { 1096 yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null); 1097 yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null); 1098 } 1099 1100 // Translation: 1101 // Substring(arg1) -> Substring(this, arg1+1, Length(this) - arg1)) 1102 // Substring(arg1, arg2) -> Substring(this, arg1+1, arg2) 1103 // Translate(ExpressionConverter parent, MethodCallExpression call)1104 internal override 
CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1105 { 1106 Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Substring"); 1107 1108 DbExpression arg1 = parent.TranslateExpression(call.Arguments[0]); 1109 1110 DbExpression target = parent.TranslateExpression(call.Object); 1111 DbExpression fromIndex = arg1.Plus(DbExpressionBuilder.Constant(1)); 1112 1113 CqtExpression length; 1114 if (call.Arguments.Count == 1) 1115 { 1116 length = parent.CreateCanonicalFunction(ExpressionConverter.Length, call, target) 1117 .Minus(arg1); 1118 } 1119 else 1120 { 1121 length = parent.TranslateExpression(call.Arguments[1]); 1122 } 1123 1124 CqtExpression substringExpression = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, target, fromIndex, length); 1125 return substringExpression; 1126 } 1127 } 1128 private sealed class RemoveTranslator : CallTranslator 1129 { RemoveTranslator()1130 internal RemoveTranslator() 1131 : base(GetMethods()) { } 1132 GetMethods()1133 private static IEnumerable<MethodInfo> GetMethods() 1134 { 1135 yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null); 1136 yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null); 1137 } 1138 1139 // Translation: 1140 // Remove(arg1) -> Substring(this, 1, arg1) 1141 // Remove(arg1, arg2) -> Concat(Substring(this, 1, arg1) , Substring(this, arg1 + arg2 + 1, Length(this) - (arg1 + arg2))) 1142 // Remove(arg1, arg2) is only supported if arg2 is a non-negative integer Translate(ExpressionConverter parent, MethodCallExpression call)1143 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1144 { 1145 Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for 
String.Remove"); 1146 1147 DbExpression thisString = parent.TranslateExpression(call.Object); 1148 DbExpression arg1 = parent.TranslateExpression(call.Arguments[0]); 1149 1150 //Substring(this, 1, arg1) 1151 CqtExpression result = 1152 parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, 1153 thisString, 1154 DbExpressionBuilder.Constant(1), 1155 arg1); 1156 1157 //Concat(result, Substring(this, (arg1 + arg2) +1, Length(this) - (arg1 + arg2))) 1158 if (call.Arguments.Count == 2) 1159 { 1160 //If there are two arguemtns, we only support cases when the second one translates to a non-negative constant 1161 CqtExpression arg2 = parent.TranslateExpression(call.Arguments[1]); 1162 if (!IsNonNegativeIntegerConstant(arg2)) 1163 { 1164 throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedStringRemoveCase(call.Method, call.Method.GetParameters()[1].Name)); 1165 } 1166 1167 // Build the second substring 1168 // (arg1 + arg2) +1 1169 CqtExpression substringStartIndex = 1170 arg1.Plus(arg2).Plus(DbExpressionBuilder.Constant(1)); 1171 1172 // Length(this) - (arg1 + arg2) 1173 CqtExpression substringLength = 1174 parent.CreateCanonicalFunction(ExpressionConverter.Length, call, thisString) 1175 .Minus(arg1.Plus(arg2)); 1176 1177 // Substring(this, substringStartIndex, substringLenght) 1178 CqtExpression secondSubstring = 1179 parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, 1180 thisString, 1181 substringStartIndex, 1182 substringLength); 1183 1184 // result = Concat (result, secondSubstring) 1185 result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, result, secondSubstring); 1186 } 1187 return result; 1188 } 1189 IsNonNegativeIntegerConstant(CqtExpression argument)1190 private static bool IsNonNegativeIntegerConstant(CqtExpression argument) 1191 { 1192 // Check whether it is a constant of type Int32 1193 if (argument.ExpressionKind != DbExpressionKind.Constant || 1194 
!TypeSemantics.IsPrimitiveType(argument.ResultType, PrimitiveTypeKind.Int32)) 1195 { 1196 return false; 1197 } 1198 1199 // Check whether its value is non-negative 1200 DbConstantExpression constantExpression = (DbConstantExpression)argument; 1201 int value = (int)constantExpression.Value; 1202 if (value < 0) 1203 { 1204 return false; 1205 } 1206 1207 return true; 1208 } 1209 } 1210 private sealed class InsertTranslator : CallTranslator 1211 { InsertTranslator()1212 internal InsertTranslator() 1213 : base(GetMethods()) { } 1214 GetMethods()1215 private static IEnumerable<MethodInfo> GetMethods() 1216 { 1217 yield return typeof(String).GetMethod("Insert", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(string) }, null); 1218 } 1219 1220 // Translation: 1221 // Insert(startIndex, value) -> Concat(Concat(Substring(this, 1, startIndex), value), Substring(this, startIndex+1, Length(this) - startIndex)) Translate(ExpressionConverter parent, MethodCallExpression call)1222 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1223 { 1224 Debug.Assert(call.Arguments.Count == 2, "Expecting 2 arguments for String.Insert"); 1225 1226 //Substring(this, 1, startIndex) 1227 DbExpression thisString = parent.TranslateExpression(call.Object); 1228 DbExpression arg1 = parent.TranslateExpression(call.Arguments[0]); 1229 CqtExpression firstSubstring = 1230 parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, 1231 thisString, 1232 DbExpressionBuilder.Constant(1), 1233 arg1); 1234 1235 //Substring(this, startIndex+1, Length(this) - startIndex) 1236 CqtExpression secondSubstring = 1237 parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, 1238 thisString, 1239 arg1.Plus(DbExpressionBuilder.Constant(1)), 1240 parent.CreateCanonicalFunction(ExpressionConverter.Length, call, thisString) 1241 .Minus(arg1)); 1242 1243 // result = Concat( Concat (firstSubstring, value), secondSubstring ) 
1244 DbExpression arg2 = parent.TranslateExpression(call.Arguments[1]); 1245 CqtExpression result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, 1246 parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, 1247 firstSubstring, 1248 arg2), 1249 secondSubstring); 1250 return result; 1251 } 1252 } 1253 private sealed class IsNullOrEmptyTranslator : CallTranslator 1254 { IsNullOrEmptyTranslator()1255 internal IsNullOrEmptyTranslator() 1256 : base(GetMethods()) { } 1257 GetMethods()1258 private static IEnumerable<MethodInfo> GetMethods() 1259 { 1260 yield return typeof(String).GetMethod("IsNullOrEmpty", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); 1261 } 1262 1263 // Translation: 1264 // IsNullOrEmpty(value) -> (IsNull(value)) OR Length(value) = 0 Translate(ExpressionConverter parent, MethodCallExpression call)1265 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1266 { 1267 Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IsNullOrEmpty"); 1268 1269 //IsNull(value) 1270 DbExpression value = parent.TranslateExpression(call.Arguments[0]); 1271 CqtExpression isNullExpression = value.IsNull(); 1272 1273 //Length(value) = 0 1274 CqtExpression emptyStringExpression = 1275 parent.CreateCanonicalFunction(ExpressionConverter.Length, call, value) 1276 .Equal(DbExpressionBuilder.Constant(0)); 1277 1278 CqtExpression result = isNullExpression.Or(emptyStringExpression); 1279 return result; 1280 } 1281 } 1282 private sealed class StringConcatTranslator : CallTranslator 1283 { StringConcatTranslator()1284 internal StringConcatTranslator() 1285 : base(GetMethods()) { } 1286 GetMethods()1287 private static IEnumerable<MethodInfo> GetMethods() 1288 { 1289 yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string) }, null); 1290 yield return 
typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string) }, null); 1291 yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string), typeof(string) }, null); 1292 } 1293 1294 // Translation: 1295 // Concat (arg1, arg2) -> Concat(arg1, arg2) 1296 // Concat (arg1, arg2, arg3) -> Concat(Concat(arg1, arg2), arg3) 1297 // Concat (arg1, arg2, arg3, arg4) -> Concat(Concat(Concat(arg1, arg2), arg3), arg4) Translate(ExpressionConverter parent, MethodCallExpression call)1298 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1299 { 1300 Debug.Assert(call.Arguments.Count >= 2 && call.Arguments.Count <= 4, "Expecting between 2 and 4 arguments for String.Concat"); 1301 1302 CqtExpression result = parent.TranslateExpression(call.Arguments[0]); 1303 for (int argIndex = 1; argIndex < call.Arguments.Count; argIndex++) 1304 { 1305 // result = Concat(result, arg[argIndex]) 1306 result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, 1307 result, 1308 parent.TranslateExpression(call.Arguments[argIndex])); 1309 } 1310 return result; 1311 } 1312 } 1313 private abstract class TrimBaseTranslator : CallTranslator 1314 { 1315 private string _canonicalFunctionName; TrimBaseTranslator(IEnumerable<MethodInfo> methods, string canonicalFunctionName)1316 protected TrimBaseTranslator(IEnumerable<MethodInfo> methods, string canonicalFunctionName) 1317 : base(methods) 1318 { 1319 _canonicalFunctionName = canonicalFunctionName; 1320 } 1321 1322 // Translation: 1323 // object.MethodName -> CanonicalFunctionName(object) 1324 // Supported only if the argument is an empty array. 
Translate(ExpressionConverter parent, MethodCallExpression call)1325 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1326 { 1327 if (!IsEmptyArray(call.Arguments[0])) 1328 { 1329 throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedTrimStartTrimEndCase(call.Method)); 1330 } 1331 1332 return parent.TranslateIntoCanonicalFunction(_canonicalFunctionName, call, call.Object); 1333 } 1334 IsEmptyArray(LinqExpression expression)1335 internal static bool IsEmptyArray(LinqExpression expression) 1336 { 1337 if (expression.NodeType == ExpressionType.NewArrayInit) 1338 { 1339 NewArrayExpression newArray = (NewArrayExpression)expression; 1340 if (newArray.Expressions.Count == 0) 1341 { 1342 return true; 1343 } 1344 } 1345 else if (expression.NodeType == ExpressionType.NewArrayBounds) 1346 { 1347 // To be empty, the array must have rank 1 with a single bound of 0 1348 NewArrayExpression newArray = (NewArrayExpression)expression; 1349 if (newArray.Expressions.Count == 1 && 1350 newArray.Expressions[0].NodeType == ExpressionType.Constant) 1351 { 1352 return object.Equals(((ConstantExpression)newArray.Expressions[0]).Value, 0); 1353 } 1354 } 1355 return false; 1356 } 1357 } 1358 private sealed class TrimTranslator : TrimBaseTranslator 1359 { TrimTranslator()1360 internal TrimTranslator() 1361 : base(GetMethods(), ExpressionConverter.Trim) { } 1362 GetMethods()1363 private static IEnumerable<MethodInfo> GetMethods() 1364 { 1365 yield return typeof(String).GetMethod("Trim", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); 1366 } 1367 } 1368 private sealed class TrimStartTranslator : TrimBaseTranslator 1369 { TrimStartTranslator()1370 internal TrimStartTranslator() 1371 : base(GetMethods(), ExpressionConverter.LTrim) { } 1372 GetMethods()1373 private static IEnumerable<MethodInfo> GetMethods() 1374 { 1375 yield return typeof(String).GetMethod("TrimStart", BindingFlags.Public | 
BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); 1376 } 1377 } 1378 private sealed class TrimEndTranslator : TrimBaseTranslator 1379 { TrimEndTranslator()1380 internal TrimEndTranslator() 1381 : base(GetMethods(), ExpressionConverter.RTrim) { } 1382 GetMethods()1383 private static IEnumerable<MethodInfo> GetMethods() 1384 { 1385 yield return typeof(String).GetMethod("TrimEnd", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); 1386 } 1387 } 1388 #endregion 1389 1390 #region Visual Basic Specific Translators 1391 private sealed class VBCanonicalFunctionDefaultTranslator : CallTranslator 1392 { 1393 private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings"; 1394 private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime"; 1395 VBCanonicalFunctionDefaultTranslator(Assembly vbAssembly)1396 internal VBCanonicalFunctionDefaultTranslator(Assembly vbAssembly) 1397 : base(GetMethods(vbAssembly)) { } 1398 GetMethods(Assembly vbAssembly)1399 private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly) 1400 { 1401 //Strings Types 1402 Type stringsType = vbAssembly.GetType(s_stringsTypeFullName); 1403 yield return stringsType.GetMethod("Trim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); 1404 yield return stringsType.GetMethod("LTrim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); 1405 yield return stringsType.GetMethod("RTrim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); 1406 yield return stringsType.GetMethod("Left", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null); 1407 yield return stringsType.GetMethod("Right", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null); 1408 1409 //DateTimeType 1410 Type dateTimeType = 
vbAssembly.GetType(s_dateAndTimeTypeFullName); 1411 yield return dateTimeType.GetMethod("Year", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); 1412 yield return dateTimeType.GetMethod("Month", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); 1413 yield return dateTimeType.GetMethod("Day", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); 1414 yield return dateTimeType.GetMethod("Hour", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); 1415 yield return dateTimeType.GetMethod("Minute", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); 1416 yield return dateTimeType.GetMethod("Second", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); 1417 } 1418 1419 // Default translator for vb static method calls into canonical functions. 1420 // Translation: 1421 // MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn) Translate(ExpressionConverter parent, MethodCallExpression call)1422 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 1423 { 1424 return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, call.Arguments.ToArray()); 1425 } 1426 } 1427 private sealed class VBCanonicalFunctionRenameTranslator : CallTranslator 1428 { 1429 private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings"; 1430 private static readonly Dictionary<MethodInfo, string> s_methodNameMap = new Dictionary<MethodInfo, string>(4); 1431 VBCanonicalFunctionRenameTranslator(Assembly vbAssembly)1432 internal VBCanonicalFunctionRenameTranslator(Assembly vbAssembly) 1433 : base(GetMethods(vbAssembly)) { } 1434 GetMethods(Assembly vbAssembly)1435 private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly) 1436 { 1437 //Strings Types 1438 Type stringsType = 
vbAssembly.GetType(s_stringsTypeFullName); 1439 yield return GetMethod(stringsType, "Len", ExpressionConverter.Length, new Type[] { typeof(string) }); 1440 yield return GetMethod(stringsType, "Mid", ExpressionConverter.Substring, new Type[] { typeof(string), typeof(int), typeof(int) }); 1441 yield return GetMethod(stringsType, "UCase", ExpressionConverter.ToUpper, new Type[] { typeof(string) }); 1442 yield return GetMethod(stringsType, "LCase", ExpressionConverter.ToLower, new Type[] { typeof(string) }); 1443 } 1444 GetMethod(Type declaringType, string methodName, string canonicalFunctionName, Type[] argumentTypes)1445 private static MethodInfo GetMethod(Type declaringType, string methodName, string canonicalFunctionName, Type[] argumentTypes) 1446 { 1447 MethodInfo methodInfo = declaringType.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static, null, argumentTypes, null); 1448 s_methodNameMap.Add(methodInfo, canonicalFunctionName); 1449 return methodInfo; 1450 } 1451 1452 // Translator for static method calls into canonical functions when only the name of the canonical function 1453 // is different from the name of the method, but the argumens match. 
// Translation:
            // MethodName(arg1, arg2, .., argn) -> CanonicalFunctionName(arg1, arg2, .., argn)
            // The canonical function name is the one recorded for call.Method in s_methodNameMap;
            // all call arguments are passed through unchanged.
            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                return parent.TranslateIntoCanonicalFunction(s_methodNameMap[call.Method], call, call.Arguments.ToArray());
            }
        }

        /// <summary>
        /// Translates Microsoft.VisualBasic.DateAndTime.DatePart(DateInterval, DateTime, FirstDayOfWeek, FirstWeekOfYear)
        /// into the canonical function whose name is the string form of the (constant) interval argument,
        /// applied to the date argument only. Only intervals listed in s_supportedIntervals are accepted.
        /// </summary>
        private sealed class VBDatePartTranslator : CallTranslator
        {
            private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";
            private const string s_dateIntervalFullName = "Microsoft.VisualBasic.DateInterval";
            private const string s_firstDayOfWeekFullName = "Microsoft.VisualBasic.FirstDayOfWeek";
            private const string s_firstWeekOfYearFullName = "Microsoft.VisualBasic.FirstWeekOfYear";

            // Interval names accepted by Translate. Assigned once in the static constructor and never
            // reassigned afterwards, hence readonly.
            private static readonly HashSet<string> s_supportedIntervals;

            internal VBDatePartTranslator(Assembly vbAssembly)
                : base(GetMethods(vbAssembly)) { }

            static VBDatePartTranslator()
            {
                s_supportedIntervals = new HashSet<string>();
                s_supportedIntervals.Add(ExpressionConverter.Year);
                s_supportedIntervals.Add(ExpressionConverter.Month);
                s_supportedIntervals.Add(ExpressionConverter.Day);
                s_supportedIntervals.Add(ExpressionConverter.Hour);
                s_supportedIntervals.Add(ExpressionConverter.Minute);
                s_supportedIntervals.Add(ExpressionConverter.Second);
            }

            // Yields the single DatePart overload handled by this translator. This is an iterator, so the
            // reflection lookups are deferred until the sequence is enumerated.
            private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
            {
                Type dateAndTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
                Type dateIntervalEnum = vbAssembly.GetType(s_dateIntervalFullName);
                Type firstDayOfWeekEnum = vbAssembly.GetType(s_firstDayOfWeekFullName);
                Type firstWeekOfYearEnum = vbAssembly.GetType(s_firstWeekOfYearFullName);

                yield return dateAndTimeType.GetMethod("DatePart", BindingFlags.Public | BindingFlags.Static, null,
                    new Type[] { dateIntervalEnum, typeof(DateTime), firstDayOfWeekEnum, firstWeekOfYearEnum }, null);
            }

            // Translation:
            // DatePart(DateInterval, date, arg3, arg4) -> 'DateInterval'(date)
            // Note: it is only supported for the values of DateInterval listed in s_supportedIntervals.
            // Throws NotSupported when the interval is not a constant or is not a supported interval;
            // arguments 3 and 4 (first day of week / first week of year) are not translated.
            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                Debug.Assert(call.Arguments.Count == 4, "Expecting 4 arguments for Microsoft.VisualBasic.DateAndTime.DatePart");

                ConstantExpression intervalLinqExpression = call.Arguments[0] as ConstantExpression;
                if (intervalLinqExpression == null)
                {
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartNonConstantInterval(call.Method, call.Method.GetParameters()[0].Name));
                }

                // The interval's string form doubles as the canonical function name.
                string intervalValue = intervalLinqExpression.Value.ToString();
                if (!s_supportedIntervals.Contains(intervalValue))
                {
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartInvalidInterval(call.Method, call.Method.GetParameters()[0].Name, intervalValue));
                }

                CqtExpression result = parent.TranslateIntoCanonicalFunction(intervalValue, call, call.Arguments[1]);
                return result;
            }
        }
        #endregion
        #endregion

        #region Sequence method translators
        private abstract class SequenceMethodTranslator
        {
            private readonly IEnumerable<SequenceMethod> _methods;
            protected SequenceMethodTranslator(params SequenceMethod[] methods) { _methods = methods; }
            internal IEnumerable<SequenceMethod> Methods { get { return _methods; } }
            internal
virtual CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
            {
                // The identified sequence method is not needed by default; defer to the two-argument overload.
                return Translate(parent, call);
            }
            internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);
            public override string ToString()
            {
                return GetType().Name;
            }
        }

        // Base for Skip/Take: the count (argument 1) is translated and combined with the translated source
        // by the derived TranslatePagingOperator.
        private abstract class PagingTranslator : UnarySequenceMethodTranslator
        {
            protected PagingTranslator(params SequenceMethod[] methods) : base(methods) { }
            protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
            {
                Debug.Assert(call.Arguments.Count == 2, "Skip and Take must have 2 arguments");

                CqtExpression translatedCount = parent.TranslateExpression(call.Arguments[1]);
                return TranslatePagingOperator(parent, operand, translatedCount);
            }
            protected abstract CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count);
        }

        // Take(source, n) -> Limit(source, n)
        private sealed class TakeTranslator : PagingTranslator
        {
            internal TakeTranslator() : base(SequenceMethod.Take) { }
            protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
            {
                return parent.Limit(operand, count);
            }
        }

        // Skip(source, n) -> Skip over a freshly aliased binding of the source.
        private sealed class SkipTranslator : PagingTranslator
        {
            internal SkipTranslator() : base(SequenceMethod.Skip) { }
            protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
            {
                DbExpressionBinding input = operand.BindAs(parent.AliasGenerator.Next());
                return parent.Skip(input, count);
            }
        }

        // Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector) -> inner join on key equality.
        // The result selector becomes a projection over the join unless it is a trivial rename.
        private sealed class JoinTranslator : SequenceMethodTranslator
        {
            internal JoinTranslator() : base(SequenceMethod.Join) { }
            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                Debug.Assert(5 == call.Arguments.Count);

                // join inputs
                CqtExpression outerInput = parent.TranslateSet(call.Arguments[0]);
                CqtExpression innerInput = parent.TranslateSet(call.Arguments[1]);

                // key selectors, then the result selector
                LambdaExpression outerKeyLambda = parent.GetLambdaExpression(call, 2);
                LambdaExpression innerKeyLambda = parent.GetLambdaExpression(call, 3);
                LambdaExpression resultLambda = parent.GetLambdaExpression(call, 4);

                // A trivial rename such as
                //   select outer as m, inner as n from (...) as outer join (...) as inner on ...
                // lets the join inputs be bound directly as m and n; anything else needs a projection.
                string outerName;
                string innerName;
                InitializerMetadata rowInitializer;
                bool isTrivialRename = IsTrivialRename(resultLambda, parent, out outerName, out innerName, out rowInitializer);

                DbExpressionBinding outerBinding;
                DbExpressionBinding innerBinding;
                CqtExpression outerKey;
                CqtExpression innerKey;
                if (isTrivialRename)
                {
                    outerKey = parent.TranslateLambda(outerKeyLambda, outerInput, outerName, out outerBinding);
                    innerKey = parent.TranslateLambda(innerKeyLambda, innerInput, innerName, out innerBinding);
                }
                else
                {
                    outerKey = parent.TranslateLambda(outerKeyLambda, outerInput, out outerBinding);
                    innerKey = parent.TranslateLambda(innerKeyLambda, innerInput, out innerBinding);
                }

                // both keys must support equality comparison
                if (!(TypeSemantics.IsEqualComparable(outerKey.ResultType) &&
                      TypeSemantics.IsEqualComparable(innerKey.ResultType)))
                {
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
                }

                var joinCondition = parent.CreateEqualsExpression(outerKey, innerKey, EqualsPattern.PositiveNullEqualityNonComposable, outerKeyLambda.Body.Type, innerKeyLambda.Body.Type);

                if (isTrivialRename)
                {
                    // The join row is simply (outer binding name: outer element, inner binding name: inner element).
                    var rowType = TypeUsage.Create(TypeHelpers.CreateRowType(
                        new List<KeyValuePair<string, TypeUsage>>()
                        {
                            new KeyValuePair<string, TypeUsage>(outerBinding.VariableName, outerBinding.VariableType),
                            new KeyValuePair<string, TypeUsage>(innerBinding.VariableName, innerBinding.VariableType)
                        },
                        rowInitializer));

                    return new DbJoinExpression(DbExpressionKind.InnerJoin, TypeUsage.Create(TypeHelpers.CreateCollectionType(rowType)), outerBinding, innerBinding, joinCondition);
                }

                // Non-trivial selector: bind the join and translate the selector against its two columns.
                DbExpressionBinding joinBinding = outerBinding.InnerJoin(innerBinding, joinCondition).BindAs(parent.AliasGenerator.Next());
                DbPropertyExpression outerElement = joinBinding.Variable.Property(outerBinding.VariableName);
                DbPropertyExpression innerElement = joinBinding.Variable.Property(innerBinding.VariableName);

                // The push order does not matter: the binding context matches on parameter reference,
                // not on ordinal.
                parent._bindingContext.PushBindingScope(new Binding(resultLambda.Parameters[0], outerElement));
                parent._bindingContext.PushBindingScope(new Binding(resultLambda.Parameters[1], innerElement));

                CqtExpression selector = parent.TranslateExpression(resultLambda.Body);

                parent._bindingContext.PopBindingScope();
                parent._bindingContext.PopBindingScope();

                return joinBinding.Project(selector);
            }
        }
        private abstract class BinarySequenceMethodTranslator : SequenceMethodTranslator
        {
            protected BinarySequenceMethodTranslator(params SequenceMethod[]
methods) : base(methods) { }

            // This method is not required to be virtual (but TranslateRight has to be). This helps improve
            // performance as this class is used frequently during CQT generation phase.
            protected CqtExpression TranslateLeft(ExpressionConverter parent, LinqExpression expr)
            {
                return parent.TranslateSet(expr);
            }

            // Translates the right-hand input; virtual so derived translators (e.g. Except) can adjust
            // converter state around the translation.
            protected virtual CqtExpression TranslateRight(ExpressionConverter parent, LinqExpression expr)
            {
                return parent.TranslateSet(expr);
            }

            // Instance form:  left.Method(right)  -> left = call.Object,       right = call.Arguments[0]
            // Extension form: Method(left, right) -> left = call.Arguments[0], right = call.Arguments[1]
            // The left input is always translated before the right one.
            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                if (null != call.Object)
                {
                    // instance method
                    Debug.Assert(1 == call.Arguments.Count);
                    CqtExpression left = this.TranslateLeft(parent, call.Object);
                    CqtExpression right = this.TranslateRight(parent, call.Arguments[0]);
                    return TranslateBinary(parent, left, right);
                }
                else
                {
                    // static extension method
                    Debug.Assert(2 == call.Arguments.Count);
                    CqtExpression left = this.TranslateLeft(parent, call.Arguments[0]);
                    CqtExpression right = this.TranslateRight(parent, call.Arguments[1]);
                    return TranslateBinary(parent, left, right);
                }
            }
            protected abstract CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right);
        }

        // Concat(left, right) -> UNION ALL (duplicates preserved)
        private class ConcatTranslator : BinarySequenceMethodTranslator
        {
            internal ConcatTranslator() : base(SequenceMethod.Concat) { }
            protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
            {
                return parent.UnionAll(left, right);
            }
        }

        // Union(left, right) -> DISTINCT(UNION ALL)
        private sealed class UnionTranslator : BinarySequenceMethodTranslator
        {
            internal UnionTranslator() : base(SequenceMethod.Union) { }
            protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
            {
                return parent.Distinct(parent.UnionAll(left, right));
            }
        }

        // Intersect(left, right) -> INTERSECT
        private sealed class IntersectTranslator : BinarySequenceMethodTranslator
        {
            internal IntersectTranslator() : base(SequenceMethod.Intersect) { }
            protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
            {
                return parent.Intersect(left, right);
            }
        }

        // Except(left, right) -> EXCEPT; the right input is translated with parent.IgnoreInclude raised.
        private sealed class ExceptTranslator : BinarySequenceMethodTranslator
        {
            internal ExceptTranslator() : base(SequenceMethod.Except) { }
            protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
            {
                return parent.Except(left, right);
            }

            // Translates the right input with parent.IgnoreInclude incremented for the duration of the call.
            // NOTE(review): presumably this suppresses Include() span handling for the subtracted set - confirm.
            // The counter is restored in a finally block so a failed translation cannot leave it raised.
            protected override CqtExpression TranslateRight(ExpressionConverter parent, LinqExpression expr)
            {
#if DEBUG
                int preValue = parent.IgnoreInclude;
#endif
                CqtExpression result;
                parent.IgnoreInclude++;
                try
                {
                    result = base.TranslateRight(parent, expr);
                }
                finally
                {
                    parent.IgnoreInclude--;
                }
#if DEBUG
                Debug.Assert(preValue == parent.IgnoreInclude);
#endif
                return result;
            }
        }

        // Base for aggregate operators. The source is argument 0; an optional lambda (argument 1) is applied
        // to it as a filter when _takesPredicate is true, otherwise as a selector. The canonical aggregate
        // named _functionName is then invoked over the (possibly wrapped) operand.
        private abstract class AggregateTranslator : SequenceMethodTranslator
        {
            private readonly string _functionName;
            private readonly bool _takesPredicate;

            protected AggregateTranslator(string functionName, bool takesPredicate, params SequenceMethod[] methods)
                : base(methods)
            {
                _takesPredicate = takesPredicate;
                _functionName = functionName;
            }

            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                bool isUnary = 1 == call.Arguments.Count;
                Debug.Assert(isUnary || 2 == call.Arguments.Count);

                CqtExpression operand = parent.TranslateSet(call.Arguments[0]);

                if (!isUnary)
                {
                    LambdaExpression lambda = parent.GetLambdaExpression(call, 1);
                    DbExpressionBinding sourceBinding;
                    CqtExpression cqtLambda = parent.TranslateLambda(lambda, operand, out sourceBinding);

                    if (_takesPredicate)
                    {
                        // treat the lambda as a filter
                        operand = parent.Filter(sourceBinding, cqtLambda);
                    }
                    else
                    {
                        // treat the lambda as a selector
                        operand = sourceBinding.Project(cqtLambda);
                    }
                }

                TypeUsage returnType = GetReturnType(parent, call);
                EdmFunction function = FindFunction(parent, call, returnType);

                operand = WrapCollectionOperand(parent, operand, returnType);
                List<DbExpression> arguments = new List<DbExpression>(1);
                arguments.Add(operand);

                DbExpression result = function.Invoke(arguments);
                result = parent.AlignTypes(result, call.Type);

                return result;
            }

            // The type used to pick the aggregate overload and to wrap the operand; by default the
            // value-layer type of the LINQ call's result type.
            protected virtual TypeUsage GetReturnType(ExpressionConverter parent, MethodCallExpression call)
            {
                Debug.Assert(parent != null, "parent != null");
                Debug.Assert(call != null, "call != null");

                return parent.GetValueLayerType(call.Type);
            }

            // If necessary, wraps the operand to ensure the appropriate aggregate overload is called:
            // when returnType differs from the collection's element type, each element is cast to returnType.
            protected virtual CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand,
                TypeUsage returnType)
            {
                // check if the operand needs to be wrapped to ensure the correct function overload is called
                if (!TypeUsageEquals(returnType, ((CollectionType)operand.ResultType.EdmType).TypeUsage))
                {
                    DbExpressionBinding operandCastBinding = operand.BindAs(parent.AliasGenerator.Next());
                    DbProjectExpression operandCastProjection = operandCastBinding.Project(operandCastBinding.Variable.CastTo(returnType));
                    operand = operandCastProjection;
                }
                return operand;
            }

            // If necessary, wraps the operand to ensure the appropriate aggregate overload is called:
            // a non-collection operand is cast to returnType when the types differ.
            protected virtual CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand,
                TypeUsage returnType)
            {
                if (!TypeUsageEquals(returnType, operand.ResultType))
                {
                    operand = operand.CastTo(returnType);
                }
                return operand;
            }

            // Finds the best function overload given the expected return type
            protected virtual EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call,
                TypeUsage argumentType)
            {
                List<TypeUsage> argTypes = new List<TypeUsage>(1);
                // In general, we use the return type as the parameter type to align LINQ semantics
                // with SQL semantics, and avoid apparent loss of precision for some LINQ aggregate operators.
                // (e.g., AVG(1, 2) = 2.0, AVG((double)1, (double)2)) = 1.5)
                argTypes.Add(argumentType);

                return parent.FindCanonicalFunction(_functionName, argTypes, true /* isGroupAggregateFunction */, call);
            }

            // Replaces an enumeration TypeUsage by a TypeUsage over the enum's underlying EDM type (facets kept);
            // any other TypeUsage is returned unchanged. Lets Max/Min bind to the numeric overload for enums.
            protected static TypeUsage GetUnderlyingTypeIfEnum(TypeUsage type)
            {
                return TypeSemantics.IsEnumerationType(type) ?
                    TypeUsage.Create(Helper.GetUnderlyingEdmTypeForEnumType(type.EdmType), type.Facets) :
                    type;
            }
        }

        private sealed class MaxTranslator : AggregateTranslator
        {
            internal MaxTranslator()
                : base("Max", false,
                    SequenceMethod.Max,
                    SequenceMethod.MaxSelector,
                    SequenceMethod.MaxInt,
                    SequenceMethod.MaxIntSelector,
                    SequenceMethod.MaxDecimal,
                    SequenceMethod.MaxDecimalSelector,
                    SequenceMethod.MaxDouble,
                    SequenceMethod.MaxDoubleSelector,
                    SequenceMethod.MaxLong,
                    SequenceMethod.MaxLongSelector,
                    SequenceMethod.MaxSingle,
                    SequenceMethod.MaxSingleSelector,
                    SequenceMethod.MaxNullableDecimal,
                    SequenceMethod.MaxNullableDecimalSelector,
                    SequenceMethod.MaxNullableDouble,
                    SequenceMethod.MaxNullableDoubleSelector,
                    SequenceMethod.MaxNullableInt,
                    SequenceMethod.MaxNullableIntSelector,
                    SequenceMethod.MaxNullableLong,
                    SequenceMethod.MaxNullableLongSelector,
                    SequenceMethod.MaxNullableSingle,
                    SequenceMethod.MaxNullableSingleSelector)
            {
            }

            protected override TypeUsage GetReturnType(ExpressionConverter parent, MethodCallExpression call)
            {
                Debug.Assert(parent != null, "parent != null");
                Debug.Assert(call != null, "call != null");

                // This allows to find and use the correct overload of Max function for enums.
                // Note that the base return type does not have to be a scalar type here (error case).
                return GetUnderlyingTypeIfEnum(base.GetReturnType(parent, call));
            }
        }

        private sealed class MinTranslator : AggregateTranslator
        {
            internal MinTranslator()
                : base("Min", false,
                    SequenceMethod.Min,
                    SequenceMethod.MinSelector,
                    SequenceMethod.MinDecimal,
                    SequenceMethod.MinDecimalSelector,
                    SequenceMethod.MinDouble,
                    SequenceMethod.MinDoubleSelector,
                    SequenceMethod.MinInt,
                    SequenceMethod.MinIntSelector,
                    SequenceMethod.MinLong,
                    SequenceMethod.MinLongSelector,
                    SequenceMethod.MinNullableDecimal,
                    SequenceMethod.MinSingle,
                    SequenceMethod.MinSingleSelector,
                    SequenceMethod.MinNullableDecimalSelector,
                    SequenceMethod.MinNullableDouble,
                    SequenceMethod.MinNullableDoubleSelector,
                    SequenceMethod.MinNullableInt,
                    SequenceMethod.MinNullableIntSelector,
                    SequenceMethod.MinNullableLong,
                    SequenceMethod.MinNullableLongSelector,
                    SequenceMethod.MinNullableSingle,
                    SequenceMethod.MinNullableSingleSelector)
            {
            }

            protected override TypeUsage GetReturnType(ExpressionConverter parent, MethodCallExpression call)
            {
                Debug.Assert(parent != null, "parent != null");
                Debug.Assert(call != null, "call != null");

                // This allows to find and use the correct overload of Min function for enums.
                // Note that the base return type does not have to be a scalar type here (error case).
                return GetUnderlyingTypeIfEnum(base.GetReturnType(parent, call));
            }
        }

        private sealed class AverageTranslator : AggregateTranslator
        {
            internal AverageTranslator()
                : base("Avg", false,
                    SequenceMethod.AverageDecimal,
                    SequenceMethod.AverageDecimalSelector,
                    SequenceMethod.AverageDouble,
                    SequenceMethod.AverageDoubleSelector,
                    SequenceMethod.AverageInt,
                    SequenceMethod.AverageIntSelector,
                    SequenceMethod.AverageLong,
                    SequenceMethod.AverageLongSelector,
                    SequenceMethod.AverageSingle,
                    SequenceMethod.AverageSingleSelector,
                    SequenceMethod.AverageNullableDecimal,
                    SequenceMethod.AverageNullableDecimalSelector,
                    SequenceMethod.AverageNullableDouble,
                    SequenceMethod.AverageNullableDoubleSelector,
                    SequenceMethod.AverageNullableInt,
                    SequenceMethod.AverageNullableIntSelector,
                    SequenceMethod.AverageNullableLong,
                    SequenceMethod.AverageNullableLongSelector,
                    SequenceMethod.AverageNullableSingle,
                    SequenceMethod.AverageNullableSingleSelector)
            {
            }
        }

        private sealed class SumTranslator : AggregateTranslator
        {
            internal SumTranslator()
                : base("Sum", false,
                    SequenceMethod.SumDecimal,
                    SequenceMethod.SumDecimalSelector,
                    SequenceMethod.SumDouble,
                    SequenceMethod.SumDoubleSelector,
                    SequenceMethod.SumInt,
                    SequenceMethod.SumIntSelector,
                    SequenceMethod.SumLong,
                    SequenceMethod.SumLongSelector,
                    SequenceMethod.SumSingle,
                    SequenceMethod.SumSingleSelector,
                    SequenceMethod.SumNullableDecimal,
                    SequenceMethod.SumNullableDecimalSelector,
                    SequenceMethod.SumNullableDouble,
                    SequenceMethod.SumNullableDoubleSelector,
                    SequenceMethod.SumNullableInt,
                    SequenceMethod.SumNullableIntSelector,
                    SequenceMethod.SumNullableLong,
                    SequenceMethod.SumNullableLongSelector,
SequenceMethod.SumNullableSingle,
                    SequenceMethod.SumNullableSingleSelector)
            {
            }
        }

        // Base for Count/LongCount: the optional lambda is always a predicate, and the aggregate runs over
        // a constant Int32 value of 1 rather than over the source elements themselves.
        private abstract class CountTranslatorBase : AggregateTranslator
        {
            protected CountTranslatorBase(string functionName, params SequenceMethod[] methods)
                : base(functionName, true, methods)
            {
            }

            // Replaces every element of the (possibly filtered) source with the constant 1.
            protected override CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
            {
                DbExpressionBinding countedInput = operand.BindAs(parent.AliasGenerator.Next());
                return countedInput.Project(DbExpressionBuilder.Constant(1));
            }

            // Ignores the operand and yields the constant 1, cast to returnType when the types differ.
            protected override CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
            {
                DbExpression one = DbExpressionBuilder.Constant(1);
                return TypeUsageEquals(one.ResultType, returnType) ? one : one.CastTo(returnType);
            }

            // For most ELinq aggregates the argument type is the return type. For "count" the argument type is
            // always Int32, because WrapCollectionOperand projects a constant Int32 value.
            protected override EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call,
                TypeUsage argumentType)
            {
                var int32Type = EdmProviderManifest.Instance.GetPrimitiveType(PrimitiveTypeKind.Int32);
                return base.FindFunction(parent, call, TypeUsage.CreateDefaultTypeUsage(int32Type));
            }
        }

        // Count(source[, predicate]) -> canonical "Count"
        private sealed class CountTranslator : CountTranslatorBase
        {
            internal CountTranslator()
                : base("Count", SequenceMethod.Count, SequenceMethod.CountPredicate)
            {
            }
        }

        // LongCount(source[, predicate]) -> canonical "BigCount"
        private sealed class LongCountTranslator : CountTranslatorBase
        {
            internal LongCountTranslator()
                : base("BigCount", SequenceMethod.LongCount, SequenceMethod.LongCountPredicate)
            {
            }
        }

        // Base for single-source operators: the source is the receiver for instance methods and the first
        // argument for static extension methods; the derived class finishes in TranslateUnary.
        private abstract class UnarySequenceMethodTranslator : SequenceMethodTranslator
        {
            protected UnarySequenceMethodTranslator(params SequenceMethod[] methods) : base(methods) { }
            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                CqtExpression operand;
                if (null == call.Object)
                {
                    // static extension method
                    Debug.Assert(1 <= call.Arguments.Count);
                    operand = parent.TranslateSet(call.Arguments[0]);
                }
                else
                {
                    // instance method
                    // NOTE(review): this assertion is always true; confirm the intended bound.
                    Debug.Assert(0 <= call.Arguments.Count);
                    operand = parent.TranslateSet(call.Object);
                }
                return TranslateUnary(parent, operand, call);
            }
            protected abstract CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call);
        }

        // AsQueryable/AsEnumerable: returns the translated source unchanged, provided it is a collection.
        private sealed class PassthroughTranslator : UnarySequenceMethodTranslator
        {
            internal PassthroughTranslator() : base(SequenceMethod.AsQueryableGeneric, SequenceMethod.AsQueryable,
SequenceMethod.AsEnumerable) { }

            protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
            {
                // Only collection-typed operands may pass through; otherwise (for instance) a String
                // would be treated as a sub-query.
                if (!TypeSemantics.IsCollectionType(operand.ResultType))
                {
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedPassthrough(
                        call.Method.Name, operand.ResultType.EdmType.Name));
                }
                return operand;
            }
        }

        // OfType<T>(source) -> OfType(source, model type of T); T must map to an entity or complex type.
        private sealed class OfTypeTranslator : UnarySequenceMethodTranslator
        {
            internal OfTypeTranslator() : base(SequenceMethod.OfType) { }
            protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
                MethodCallExpression call)
            {
                Type requestedClrType = call.Method.GetGenericArguments()[0];
                TypeUsage requestedModelType;

                // The model type must exist in the perspective and be an EntityType or ComplexType;
                // OfType() is not a valid operation on scalars, enumerations, collections, etc.
                bool isEntityOrComplex =
                    parent.TryGetValueLayerType(requestedClrType, out requestedModelType) &&
                    (TypeSemantics.IsEntityType(requestedModelType) || TypeSemantics.IsComplexType(requestedModelType));
                if (!isEntityOrComplex)
                {
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_InvalidOfTypeResult(DescribeClrType(requestedClrType)));
                }

                // Restrict the original query to results of the requested type.
                return parent.OfType(operand, requestedModelType);
            }
        }

        // Distinct(source) -> DISTINCT(source)
        private sealed class DistinctTranslator : UnarySequenceMethodTranslator
        {
            internal DistinctTranslator() : base(SequenceMethod.Distinct) { }
            protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
                MethodCallExpression call)
            {
                return parent.Distinct(operand);
            }
        }

        // Any(source) -> NOT(IsEmpty(source)), i.e. "exists".
        private sealed class AnyTranslator : UnarySequenceMethodTranslator
        {
            internal AnyTranslator() : base(SequenceMethod.Any) { }
            protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
                MethodCallExpression call)
            {
                DbExpression isEmpty = operand.IsEmpty();
                return isEmpty.Not();
            }
        }

        // Base for operators of the form Method(source, lambda): translates the source (argument 0) and the
        // lambda (argument 1) bound over it, then lets the derived class combine them in TranslateOneLambda.
        private abstract class OneLambdaTranslator : SequenceMethodTranslator
        {
            internal OneLambdaTranslator(params SequenceMethod[] methods) : base(methods) { }
            internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
            {
                CqtExpression ignoredSource;
                DbExpressionBinding ignoredBinding;
                CqtExpression ignoredLambda;
                return Translate(parent, call, out ignoredSource, out ignoredBinding, out ignoredLambda);
            }

            // Helper for translation: besides returning the combined result, exposes the translated source,
            // its binding and the translated lambda through out parameters.
            protected CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, out CqtExpression source, out DbExpressionBinding sourceBinding, out CqtExpression lambda)
            {
                Debug.Assert(2 <= call.Arguments.Count);

                source = parent.TranslateExpression(call.Arguments[0]);

                LambdaExpression linqLambda = parent.GetLambdaExpression(call, 1);
                lambda = parent.TranslateLambda(linqLambda, source, out sourceBinding);

                return TranslateOneLambda(parent, sourceBinding, lambda);
            }

            protected abstract CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda);
        }

        // Any(source, predicate) -> ANY(predicate) over the source binding.
        private sealed class AnyPredicateTranslator : OneLambdaTranslator
        {
            internal AnyPredicateTranslator() : base(SequenceMethod.AnyPredicate) { }
            protected override
CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) 2132 { 2133 return sourceBinding.Any(lambda); 2134 } 2135 } 2136 private sealed class AllTranslator : OneLambdaTranslator 2137 { AllTranslator()2138 internal AllTranslator() : base(SequenceMethod.All) { } TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)2139 protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) 2140 { 2141 return sourceBinding.All(lambda); 2142 } 2143 } 2144 private sealed class WhereTranslator : OneLambdaTranslator 2145 { WhereTranslator()2146 internal WhereTranslator() : base(SequenceMethod.Where) { } TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)2147 protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) 2148 { 2149 return parent.Filter(sourceBinding, lambda); 2150 } 2151 } 2152 private sealed class SelectTranslator : OneLambdaTranslator 2153 { SelectTranslator()2154 internal SelectTranslator() : base(SequenceMethod.Select) { } Translate(ExpressionConverter parent, MethodCallExpression call)2155 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) 2156 { 2157 CqtExpression source; 2158 DbExpressionBinding sourceBinding; 2159 CqtExpression lambda; 2160 CqtExpression result = Translate(parent, call, out source, out sourceBinding, out lambda); 2161 return result; 2162 } 2163 TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)2164 protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) 2165 { 2166 return parent.Project(sourceBinding, lambda); 2167 } 2168 } 2169 private sealed class 
DefaultIfEmptyTranslator : SequenceMethodTranslator
            {
                internal DefaultIfEmptyTranslator()
                    : base(SequenceMethod.DefaultIfEmpty, SequenceMethod.DefaultIfEmptyValue)
                {
                }

                /// <summary>
                /// Builds <c>{1} LEFT OUTER JOIN operand ON true</c> and projects the right side, so that an empty
                /// operand still yields one row. When a non-null default value is involved, a sentinel column
                /// distinguishes "no row matched" from "row present" and a CASE substitutes the default.
                /// </summary>
                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    DbExpression operand = parent.TranslateSet(call.Arguments[0]);

                    // get default value (different translation for non-null defaults)
                    DbExpression defaultValue = call.Arguments.Count == 2 ?
                        parent.TranslateExpression(call.Arguments[1]) :
                        GetDefaultValue(parent, call.Type);

                    DbExpression left = DbExpressionBuilder.NewCollection(new DbExpression[] { 1 });
                    DbExpressionBinding leftBinding = left.BindAs(parent.AliasGenerator.Next());

                    // DefaultIfEmpty(value) syntax we may require a sentinel flag to indicate default value substitution
                    bool requireSentinel = !(null == defaultValue || defaultValue.ExpressionKind == DbExpressionKind.Null);
                    if (requireSentinel)
                    {
                        DbExpressionBinding o = operand.BindAs(parent.AliasGenerator.Next());
                        operand = o.Project(new Row(((DbExpression)1).As("sentinel"), o.Variable.As("value")));
                    }

                    DbExpressionBinding rightBinding = operand.BindAs(parent.AliasGenerator.Next());
                    DbExpression join = DbExpressionBuilder.LeftOuterJoin(leftBinding, rightBinding, true);
                    DbExpressionBinding joinBinding = join.BindAs(parent.AliasGenerator.Next());
                    DbExpression projection = joinBinding.Variable.Property(rightBinding.VariableName);

                    // Use a case statement on the sentinel flag to drop the default value in where required
                    if (requireSentinel)
                    {
                        projection = DbExpressionBuilder.Case(new[] { projection.Property("sentinel").IsNull() }, new[] { defaultValue }, projection.Property("value"));
                    }

                    DbExpression spannedProjection = joinBinding.Project(projection);
                    parent.ApplySpanMapping(operand, spannedProjection);
                    return spannedProjection;
                }

                /// <summary>
                /// Translates the value returned by <c>TypeSystem.GetDefaultValue</c> for the element type of
                /// <paramref name="resultType"/> into a constant, or returns null when that default is null.
                /// </summary>
                private static DbExpression GetDefaultValue(ExpressionConverter parent, Type resultType)
                {
                    Type elementType = TypeSystem.GetElementType(resultType);
                    object defaultValue = TypeSystem.GetDefaultValue(elementType);
                    DbExpression result = null == defaultValue ?
                        null :
                        parent.TranslateExpression(Expression.Constant(defaultValue, elementType));
                    return result;
                }
            }

            /// <summary>
            /// Translates Queryable/Enumerable.Contains (and, via <see cref="TranslateContains"/>,
            /// ICollection&lt;T&gt;.Contains) into an equality-based predicate.
            /// </summary>
            private sealed class ContainsTranslator : SequenceMethodTranslator
            {
                internal ContainsTranslator()
                    : base(SequenceMethod.Contains)
                {
                }

                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    return TranslateContains(parent, call.Arguments[0], call.Arguments[1]);
                }

                /// <summary>
                /// Builds a balanced OR tree of equality comparisons between <paramref name="left"/> and each
                /// element of <paramref name="rightList"/>. Callers in this class only pass non-empty lists.
                /// </summary>
                private static DbExpression TranslateContainsHelper(ExpressionConverter parent, CqtExpression left, IEnumerable<DbExpression> rightList, EqualsPattern pattern, Type leftType, Type rightType)
                {
                    var predicates = rightList.
                        Select(argument => parent.CreateEqualsExpression(left, argument, pattern, leftType, rightType));
                    var expressions = new List<DbExpression>(predicates);
                    var cqt = System.Data.Common.Utils.Helpers.BuildBalancedTreeInPlace(expressions,
                        (prev, next) => prev.Or(next)
                    );
                    return cqt;
                }

                /// <summary>
                /// Translates <c>source.Contains(value)</c>. An inline collection source (a NewInstance expression)
                /// is expanded into a disjunction of equality tests; any other source becomes
                /// <c>EXISTS(source WHERE element = value)</c>.
                /// </summary>
                internal static DbExpression TranslateContains(ExpressionConverter parent, Expression sourceExpression, Expression valueExpression)
                {
                    DbExpression source = parent.NormalizeSetSource(parent.TranslateExpression(sourceExpression));
                    DbExpression value = parent.TranslateExpression(valueExpression);
                    Type sourceArgumentType = TypeSystem.GetElementType(sourceExpression.Type);

                    if (source.ExpressionKind == DbExpressionKind.NewInstance)
                    {
                        IList<DbExpression> arguments = ((DbNewInstanceExpression)source).Arguments;
                        if (arguments.Count > 0)
                        {
                            if (!parent._funcletizer.RootContext.ContextOptions.UseCSharpNullComparisonBehavior)
                            {
                                return TranslateContainsHelper(parent, value, arguments, EqualsPattern.Store, sourceArgumentType, valueExpression.Type);
                            }

                            // Replaces this => (tbl.Col = 1 AND tbl.Col IS NOT NULL) OR (tbl.Col = 2 AND tbl.Col IS NOT NULL) OR ...
                            // with this => (tbl.Col = 1 OR tbl.Col = 2 OR ...) AND (tbl.Col IS NOT NULL))
                            // which in turn gets simplified to this => (tbl.Col IN (1, 2, ...) AND (tbl.Col IS NOT NULL)) in SqlGenerator

                            // Split the elements into constants and everything else in a single pass, preserving
                            // their relative order. Materializing avoids enumerating a lazy filter twice
                            // (once for the emptiness check and again inside the helper).
                            List<DbExpression> constantArguments = new List<DbExpression>();
                            List<DbExpression> otherArguments = new List<DbExpression>();
                            foreach (DbExpression argument in arguments)
                            {
                                if (argument.ExpressionKind == DbExpressionKind.Constant)
                                {
                                    constantArguments.Add(argument);
                                }
                                else
                                {
                                    otherArguments.Add(argument);
                                }
                            }

                            CqtExpression constantCqt = null;
                            if (constantArguments.Count > 0)
                            {
                                constantCqt = TranslateContainsHelper(parent, value, constantArguments, EqualsPattern.PositiveNullEqualityNonComposable, sourceArgumentType, valueExpression.Type);
                                constantCqt = constantCqt.And(value.IsNull().Not());
                            }

                            // Does not optimize conversion of variables embedded in the list.
                            CqtExpression otherCqt = null;
                            if (otherArguments.Count > 0)
                            {
                                otherCqt = TranslateContainsHelper(parent, value, otherArguments, EqualsPattern.PositiveNullEqualityComposable, sourceArgumentType, valueExpression.Type);
                            }

                            if (constantCqt == null)
                            {
                                return otherCqt;
                            }
                            if (otherCqt == null)
                            {
                                return constantCqt;
                            }
                            return constantCqt.Or(otherCqt);
                        }

                        // Contains over an empty inline collection is always false.
                        return false;
                    }

                    DbExpressionBinding sourceBinding = source.BindAs(parent.AliasGenerator.Next());
                    EqualsPattern pattern = EqualsPattern.Store;
                    if (parent._funcletizer.RootContext.ContextOptions.UseCSharpNullComparisonBehavior)
                    {
                        pattern = EqualsPattern.PositiveNullEqualityComposable;
                    }
                    return sourceBinding.Filter(parent.CreateEqualsExpression(sourceBinding.Variable, value, pattern, sourceArgumentType, valueExpression.Type)).Exists();
                }
            }

            /// <summary>
            /// Base class for the predicate-less First/FirstOrDefault translators (and, through
            /// SingleTranslatorBase, Single/SingleOrDefault).
            /// </summary>
            private abstract class FirstTranslatorBase : UnarySequenceMethodTranslator
            {
                protected FirstTranslatorBase(params SequenceMethod[] methods) : base(methods) { }

                /// <summary>Applies a row limit to <paramref name="expression"/>; overridden by the Single variants.</summary>
                protected virtual CqtExpression LimitResult(ExpressionConverter parent, CqtExpression expression)
                {
// Only need the first result.
                    return parent.Limit(expression, DbExpressionBuilder.Constant(1));
                }

                protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
                {
                    CqtExpression result = LimitResult(parent, operand);

                    // If this FirstOrDefault/SingleOrDefault() operation is the root of the query,
                    // then the evaluation is performed in the client over the resulting set,
                    // to provide the same semantics as Linq to Objects. Otherwise, an Element
                    // expression is applied to retrieve the single element (or null, if empty)
                    // from the output set.
                    if (!parent.IsQueryRoot(call))
                    {
                        result = result.Element();
                        result = AddDefaultCase(parent, result, call.Type);
                    }

                    // Span is preserved over First/FirstOrDefault with or without a predicate
                    Span inputSpan = null;
                    if (parent.TryGetSpan(operand, out inputSpan))
                    {
                        parent.AddSpanMapping(result, inputSpan);
                    }

                    return result;
                }

                /// <summary>
                /// When <c>TypeSystem.GetDefaultValue(elementType)</c> is non-null, wraps <paramref name="element"/>
                /// in <c>CASE WHEN element IS NULL THEN default ELSE element END</c>; otherwise returns
                /// <paramref name="element"/> unchanged.
                /// </summary>
                /// <param name="parent">The converter, used to build the IS NULL test.</param>
                /// <param name="element">A single-element expression that may be null.</param>
                /// <param name="elementType">The CLR type whose default value is substituted.</param>
                internal static CqtExpression AddDefaultCase(ExpressionConverter parent, CqtExpression element, Type elementType)
                {
                    // Retrieve default value.
                    object defaultValue = TypeSystem.GetDefaultValue(elementType);
                    if (null == defaultValue)
                    {
                        // Already null, which is the implicit default for DbElementExpression
                        return element;
                    }

                    Debug.Assert(TypeSemantics.IsScalarType(element.ResultType), "Primitive or enum type expected at this point.");

                    // Otherwise, use the default value for the type
                    List<CqtExpression> whenExpressions = new List<CqtExpression>(1);
                    whenExpressions.Add(parent.CreateIsNullExpression(element, elementType));
                    List<CqtExpression> thenExpressions = new List<CqtExpression>(1);
                    thenExpressions.Add(DbExpressionBuilder.Constant(element.ResultType, defaultValue));
                    DbCaseExpression caseExpression = DbExpressionBuilder.Case(whenExpressions, thenExpressions, element);
                    return caseExpression;
                }
            }

            /// <summary>Translates First(); only supported at the root of the query.</summary>
            private sealed class FirstTranslator : FirstTranslatorBase
            {
                internal FirstTranslator() : base(SequenceMethod.First) { }

                protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
                {
                    if (!parent.IsQueryRoot(call))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
                    }
                    return base.TranslateUnary(parent, operand, call);
                }
            }

            /// <summary>Translates FirstOrDefault(); supported both at the root and nested.</summary>
            private sealed class FirstOrDefaultTranslator : FirstTranslatorBase
            {
                internal FirstOrDefaultTranslator() : base(SequenceMethod.FirstOrDefault) { }
            }

            /// <summary>
            /// Base class for the predicate-less Single/SingleOrDefault translators: only supported at the
            /// query root, and fetches two rows so that a second element can be detected.
            /// </summary>
            private abstract class SingleTranslatorBase : FirstTranslatorBase
            {
                protected SingleTranslatorBase(params SequenceMethod[] methods) : base(methods) { }

                protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
                {
                    if (!parent.IsQueryRoot(call))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedSingle);
                    }
                    return base.TranslateUnary(parent, operand, call);
                }

                protected override CqtExpression LimitResult(ExpressionConverter parent, CqtExpression expression)
                {
                    // Only need two results - one to return as the actual result and another so we can throw if there is more than one
                    return parent.Limit(expression, DbExpressionBuilder.Constant(2));
                }
            }

            private sealed class SingleTranslator : SingleTranslatorBase
            {
                internal SingleTranslator() : base(SequenceMethod.Single) { }
            }

            private sealed class SingleOrDefaultTranslator : SingleTranslatorBase
            {
                internal SingleOrDefaultTranslator() : base(SequenceMethod.SingleOrDefault) { }
            }

            /// <summary>
            /// Base class for First/FirstOrDefault/Single/SingleOrDefault with a predicate: the predicate
            /// becomes a filter over the source, which is then limited (at the root) or reduced to an
            /// Element expression (when nested).
            /// </summary>
            private abstract class FirstPredicateTranslatorBase : OneLambdaTranslator
            {
                protected FirstPredicateTranslatorBase(params SequenceMethod[] methods) : base(methods) { }

                /// <summary>Applies a row limit to the filtered input; overridden by the Single variants.</summary>
                protected virtual CqtExpression RestrictResult(ExpressionConverter parent, CqtExpression expression)
                {
                    // Only need the first result.
                    return parent.Limit(expression, DbExpressionBuilder.Constant(1));
                }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    // Convert the input set and the predicate into a filter expression
                    CqtExpression input = base.Translate(parent, call);

                    // If this First/FirstOrDefault/Single/SingleOrDefault is the root of the query,
                    // then the actual result is produced by calling First/Single() or
                    // FirstOrDefault()/SingleOrDefault() on the filtered input set, which RestrictResult
                    // limits to one element (First variants) or two elements (Single variants, so that a
                    // second match can be detected).
                    if (parent.IsQueryRoot(call))
                    {
                        // Calling ExpressionConverter.Limit propagates the Span.
                        return RestrictResult(parent, input);
                    }
                    else
                    {
                        CqtExpression element = input.Element();
                        element = FirstTranslatorBase.AddDefaultCase(parent, element, call.Type);

                        // Span is preserved over First/FirstOrDefault with or without a predicate
                        Span inputSpan = null;
                        if (parent.TryGetSpan(input, out inputSpan))
                        {
                            parent.AddSpanMapping(element, inputSpan);
                        }

                        return element;
                    }
                }

                protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
                {
                    return parent.Filter(sourceBinding, lambda);
                }
            }

            /// <summary>Translates First(predicate); only supported at the root of the query.</summary>
            private sealed class FirstPredicateTranslator : FirstPredicateTranslatorBase
            {
                internal FirstPredicateTranslator() : base(SequenceMethod.FirstPredicate) { }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    if (!parent.IsQueryRoot(call))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
                    }
                    return base.Translate(parent, call);
                }
            }

            private sealed class FirstOrDefaultPredicateTranslator : FirstPredicateTranslatorBase
            {
                internal FirstOrDefaultPredicateTranslator() : base(SequenceMethod.FirstOrDefaultPredicate) { }
            }

            /// <summary>
            /// Base class for Single/SingleOrDefault with a predicate: only supported at the query root,
            /// and fetches two rows so that a second match can be detected.
            /// </summary>
            private abstract class SinglePredicateTranslatorBase : FirstPredicateTranslatorBase
            {
                protected SinglePredicateTranslatorBase(params SequenceMethod[] methods) : base(methods) { }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    if (!parent.IsQueryRoot(call))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedSingle);
                    }
                    return base.Translate(parent, call);
                }

                protected override CqtExpression RestrictResult(ExpressionConverter parent, CqtExpression expression)
                {
                    // Only need two results - one to return and another to see if it wasn't alone to throw.
                    return parent.Limit(expression, DbExpressionBuilder.Constant(2));
                }
            }

            private sealed class SinglePredicateTranslator : SinglePredicateTranslatorBase
            {
                internal SinglePredicateTranslator() : base(SequenceMethod.SinglePredicate) { }
            }

            private sealed class SingleOrDefaultPredicateTranslator : SinglePredicateTranslatorBase
            {
                internal SingleOrDefaultPredicateTranslator() : base(SequenceMethod.SingleOrDefaultPredicate) { }
            }

            /// <summary>
            /// Translates SelectMany(source, collectionSelector) and
            /// SelectMany(source, collectionSelector, resultSelector) into an APPLY over the source.
            /// </summary>
            private sealed class SelectManyTranslator : OneLambdaTranslator
            {
                internal SelectManyTranslator() : base(SequenceMethod.SelectMany, SequenceMethod.SelectManyResultSelector) { }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    // perform a cross apply to implement the core logic for SelectMany (this translates the collection selector):
                    // SelectMany(i, Func<i, IEnum<o>> collectionSelector) =>
                    //      i CROSS APPLY collectionSelector(i)
                    // The cross-apply yields a collection <left, right> from which we yield either the right hand side (when
                    // no explicit resultSelector is given) or over which we apply the resultSelector Lambda expression.

                    LambdaExpression resultSelector = (call.Arguments.Count == 3) ? parent.GetLambdaExpression(call, 2) : null;

                    CqtExpression apply = base.Translate(parent, call);

                    // try detecting the linq pattern for a left outer join and produce a simpler c-tree for it.
DbExpressionBinding applyInput;
                    EdmProperty lojRightInput;
                    bool isLeftOuterJoin = IsLeftOuterJoin(apply, out applyInput, out lojRightInput);
                    if (isLeftOuterJoin)
                    {
                        // 1)
                        // if apply looks like a cross apply with right input being a loj of {1} to a collection from the apply's left input:
                        //      (
                        //          select o, (select ...) as lojRightInput
                        //          from (...) as o
                        //      ) as x
                        //      CROSS apply
                        //      (
                        //          select loj
                        //          from {1} left outer join x.lojRightInput as loj on true
                        //      ) as y
                        // then rewrite it as outer apply
                        //      (
                        //          select o, (select ...) as lojRightInput
                        //          from (...) as o
                        //      ) as x
                        //      OUTER apply
                        //      x.lojRightInput as loj
                        //
                        // 2)
                        // if there is a trivial resultSelector that would produce something like this:
                        //      select x as m, loj as n
                        //      from (...) as x outer apply (...) as loj
                        // then rewrite it as
                        //      (...) as m outer apply (...) as n
                        string outerBindingName;
                        string innerBindingName;
                        InitializerMetadata initializerMetadata;
                        if (resultSelector != null && IsTrivialRename(resultSelector, parent, out outerBindingName, out innerBindingName, out initializerMetadata))
                        {
                            // It is #1 and #2 as described above:
                            // - produce the outer apply
                            // - name inputs as specified in the resultSelector
                            // - return the apply.
                            var newInput = applyInput.Expression.BindAs(outerBindingName);
                            var newApply = newInput.Variable.Property(lojRightInput.Name).BindAs(innerBindingName);

                            var resultType = TypeUsage.Create(TypeHelpers.CreateRowType(
                                new List<KeyValuePair<string, TypeUsage>>()
                                {
                                    new KeyValuePair<string, TypeUsage>(newInput.VariableName, newInput.VariableType),
                                    new KeyValuePair<string, TypeUsage>(newApply.VariableName, newApply.VariableType)
                                },
                                initializerMetadata));

                            return new DbApplyExpression(DbExpressionKind.OuterApply, TypeUsage.Create(TypeHelpers.CreateCollectionType(resultType)), newInput, newApply);
                        }
                        else
                        {
                            // It is just #1 as described above,
                            // so produce the outer apply and let the logic below generate projection using the resultSelector.
                            apply = applyInput.OuterApply(applyInput.Variable.Property(lojRightInput).BindAs(parent.AliasGenerator.Next()));
                        }
                    }

                    DbExpressionBinding applyBinding = apply.BindAs(parent.AliasGenerator.Next());
                    RowType applyRowType = (RowType)(applyBinding.Variable.ResultType.EdmType);
                    CqtExpression projectRight = applyBinding.Variable.Property(applyRowType.Properties[1]);

                    CqtExpression resultProjection;
                    if (resultSelector != null)
                    {
                        CqtExpression projectLeft = applyBinding.Variable.Property(applyRowType.Properties[0]);

                        // add the left and right projection terms to the binding context
                        parent._bindingContext.PushBindingScope(new Binding(resultSelector.Parameters[0], projectLeft));
                        parent._bindingContext.PushBindingScope(new Binding(resultSelector.Parameters[1], projectRight));

                        // translate the result selector
                        resultProjection = parent.TranslateSet(resultSelector.Body);

                        // pop binding context
                        parent._bindingContext.PopBindingScope();
                        parent._bindingContext.PopBindingScope();
                    }
                    else
                    {
                        // project out the right hand side of the apply
                        resultProjection = projectRight;
                    }

                    // wrap result projection in project expression
                    return applyBinding.Project(resultProjection);
                }

                /// <summary>
                /// Detects the command-tree shape that LINQ's left-outer-join idiom (SelectMany over DefaultIfEmpty)
                /// produces:
                /// <code>
                ///     (select o, (select ...) as lojRightInput from (...) as o) as x
                ///     cross apply
                ///     (select loj from {1} left outer join x.lojRightInput as loj on true) as y
                /// </code>
                /// </summary>
                /// <param name="cqtExpression">The candidate apply expression.</param>
                /// <param name="crossApplyInput">On success, the left input (<c>x</c>) of the cross apply; otherwise null.</param>
                /// <param name="lojRightInput">On success, the property of <c>x</c> used as the right side of the left outer join; otherwise null.</param>
                /// <returns>True if the pattern matches.</returns>
                private static bool IsLeftOuterJoin(CqtExpression cqtExpression, out DbExpressionBinding crossApplyInput, out EdmProperty lojRightInput)
                {
                    crossApplyInput = null;
                    lojRightInput = null;

                    if (cqtExpression.ExpressionKind != DbExpressionKind.CrossApply)
                    {
                        return false;
                    }
                    var crossApply = (DbApplyExpression)cqtExpression;

                    if (crossApply.Input.VariableType.EdmType.BuiltInTypeKind != BuiltInTypeKind.RowType)
                    {
                        return false;
                    }
                    var crossApplyInputRowType = (RowType)crossApply.Input.VariableType.EdmType;

                    // rightProject = (select loj
                    //                 from {1} left outer join x.lojRightInput as loj on true)
                    if (crossApply.Apply.Expression.ExpressionKind != DbExpressionKind.Project)
                    {
                        return false;
                    }
                    var rightProject = (DbProjectExpression)crossApply.Apply.Expression;

                    // loj = {1} left outer join x.lojRightInput as loj on true
                    if (rightProject.Input.Expression.ExpressionKind != DbExpressionKind.LeftOuterJoin)
                    {
                        return false;
                    }
                    var loj = (DbJoinExpression)rightProject.Input.Expression;

                    if (rightProject.Projection.ExpressionKind != DbExpressionKind.Property)
                    {
                        return false;
                    }
                    var rightProjectProjection = (DbPropertyExpression)rightProject.Projection;

                    // make sure that in
                    //    rightProject = (select loj
                    //                    from {1} left outer join x.lojRightInput as loj on true)
                    // loj comes from the right side of the left outer join.
                    if (rightProjectProjection.Instance != rightProject.Input.Variable ||
                        rightProjectProjection.Property.Name != loj.Right.VariableName ||
                        loj.JoinCondition.ExpressionKind != DbExpressionKind.Constant)
                    {
                        return false;
                    }
                    var lojCondition = (DbConstantExpression)loj.JoinCondition;

                    // make sure that in
                    //    rightProject = (select loj
                    //                    from {1} left outer join x.lojRightInput as loj on true)
                    // the left outer join condition is "true".
                    if (!(lojCondition.Value is bool) || !(bool)lojCondition.Value)
                    {
                        return false;
                    }

                    // make sure that in
                    //    rightProject = (select loj
                    //                    from {1} left outer join x.lojRightInput as loj on true)
                    // the left input into the left outer join condition is a single-element collection "{some constant}"
                    if (loj.Left.Expression.ExpressionKind != DbExpressionKind.NewInstance)
                    {
                        return false;
                    }
                    var lojLeft = (DbNewInstanceExpression)loj.Left.Expression;
                    if (lojLeft.Arguments.Count != 1 || lojLeft.Arguments[0].ExpressionKind != DbExpressionKind.Constant)
                    {
                        return false;
                    }

                    // make sure that in
                    //    rightProject = (select loj
                    //                    from {1} left outer join x.lojRightInput as loj on true)
                    // the x.lojRightInput comes from the left side of the cross apply
                    if (loj.Right.Expression.ExpressionKind != DbExpressionKind.Property)
                    {
                        return false;
                    }
                    var lojRight = (DbPropertyExpression)loj.Right.Expression;
                    if (lojRight.Instance != crossApply.Input.Variable)
                    {
                        return false;
                    }
                    var lojRightValueSource = crossApplyInputRowType.Properties.SingleOrDefault(p => p.Name == lojRight.Property.Name);
                    if (lojRightValueSource == null)
                    {
                        return false;
                    }

                    crossApplyInput = crossApply.Input;
                    lojRightInput = lojRightValueSource;

                    return true;
                }

                protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
                {
                    // elements of the inner selector should be used
                    lambda = parent.NormalizeSetSource(lambda);
                    DbExpressionBinding applyBinding = lambda.BindAs(parent.AliasGenerator.Next());
                    DbApplyExpression crossApply = sourceBinding.CrossApply(applyBinding);
                    return crossApply;
                }
            }

            /// <summary>
            /// Translates Cast&lt;T&gt;() into a projection that casts each element of the source to the
            /// target element type.
            /// </summary>
            private sealed class CastMethodTranslator : SequenceMethodTranslator
            {
                internal CastMethodTranslator() : base(SequenceMethod.Cast) { }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    // Translate source
                    CqtExpression source = parent.TranslateSet(call.Arguments[0]);

                    // Figure out the type to cast to
                    Type toClrType = TypeSystem.GetElementType(call.Type);
                    Type fromClrType = TypeSystem.GetElementType(call.Arguments[0].Type);

                    // Get binding to the elements of the input source
                    DbExpressionBinding binding = source.BindAs(parent.AliasGenerator.Next());

                    CqtExpression cast = parent.CreateCastExpression(binding.Variable, toClrType, fromClrType);
                    return parent.Project(binding, cast);
                }
            }

            /// <summary>
            /// Translates the GroupBy overloads (with or without element and result selectors) into a
            /// DbGroupByExpression projected as grouping rows.
            /// </summary>
            private sealed class GroupByTranslator : SequenceMethodTranslator
            {
                internal GroupByTranslator()
                    : base(SequenceMethod.GroupBy, SequenceMethod.GroupByElementSelector, SequenceMethod.GroupByElementSelectorResultSelector,
                           SequenceMethod.GroupByResultSelector)
                {
}

                // Creates a Cqt GroupByExpression with a group aggregate
                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
                {
                    // translate source
                    CqtExpression source = parent.TranslateSet(call.Arguments[0]);

                    // translate key selector
                    LambdaExpression keySelectorLinq = parent.GetLambdaExpression(call, 1);
                    DbGroupExpressionBinding sourceGroupBinding;
                    CqtExpression keySelector = parent.TranslateLambda(keySelectorLinq, source, out sourceGroupBinding);

                    // validate the key selector: grouping requires an equality-comparable key
                    if (!TypeSemantics.IsEqualComparable(keySelector.ResultType))
                    {
                        // to avoid confusing error message about the "distinct" type, pre-emptively raise an exception
                        // about the group by key selector
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
                    }

                    List<KeyValuePair<string, DbExpression>> keys = new List<KeyValuePair<string, DbExpression>>();
                    List<KeyValuePair<string, DbAggregate>> aggregates = new List<KeyValuePair<string, DbAggregate>>();
                    keys.Add(new KeyValuePair<string, CqtExpression>(KeyColumnName, keySelector));
                    aggregates.Add(new KeyValuePair<string, DbAggregate>(GroupColumnName, sourceGroupBinding.GroupAggregate));

                    DbExpression groupBy = sourceGroupBinding.GroupBy(keys, aggregates);
                    DbExpressionBinding groupByBinding = groupBy.BindAs(parent.AliasGenerator.Next());

                    // interpret element selector if needed
                    CqtExpression selection = groupByBinding.Variable.Property(GroupColumnName);

                    bool hasElementSelector = sequenceMethod == SequenceMethod.GroupByElementSelector ||
                        sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector;

                    //Create a project over the group by
                    if (hasElementSelector)
                    {
                        LambdaExpression elementSelectorLinq = parent.GetLambdaExpression(call, 2);
                        DbExpressionBinding elementSelectorSourceBinding;
                        CqtExpression elementSelector = parent.TranslateLambda(elementSelectorLinq, selection, out elementSelectorSourceBinding);
                        selection = elementSelectorSourceBinding.Project(elementSelector);
                    }

                    // create top level projection <key, group>
                    CqtExpression[] projectionTerms = new CqtExpression[2];
                    projectionTerms[0] = groupByBinding.Variable.Property(KeyColumnName);
                    projectionTerms[1] = selection;

                    // build projection type with initializer information
                    List<EdmProperty> properties = new List<EdmProperty>(2);
                    properties.Add(new EdmProperty(KeyColumnName, projectionTerms[0].ResultType));
                    properties.Add(new EdmProperty(GroupColumnName, projectionTerms[1].ResultType));
                    InitializerMetadata initializerMetadata = InitializerMetadata.CreateGroupingInitializer(
                        parent.EdmItemCollection, TypeSystem.GetElementType(call.Type));
                    RowType rowType = new RowType(properties, initializerMetadata);
                    TypeUsage rowTypeUsage = TypeUsage.Create(rowType);

                    CqtExpression topLevelProject = groupByBinding.Project(DbExpressionBuilder.New(rowTypeUsage, projectionTerms));

                    // GroupBy may include a result selector; handle it
                    return ProcessResultSelector(parent, call, sequenceMethod, topLevelProject, topLevelProject);
                }

                /// <summary>
                /// Applies the GroupBy result selector, if the overload has one (argument 2 for
                /// GroupByResultSelector, argument 3 for GroupByElementSelectorResultSelector), as a projection
                /// over <paramref name="topLevelProject"/> with its (Key, Group) parameters bound to the
                /// corresponding columns. Returns <paramref name="result"/> unchanged when there is no result selector.
                /// </summary>
                private static DbExpression ProcessResultSelector(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod, CqtExpression topLevelProject, DbExpression result)
                {
                    // interpret result selector if needed
                    LambdaExpression resultSelectorLinqExpression = null;
                    if (sequenceMethod == SequenceMethod.GroupByResultSelector)
                    {
                        resultSelectorLinqExpression = parent.GetLambdaExpression(call, 2);
                    }
                    else if (sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector)
                    {
                        resultSelectorLinqExpression = parent.GetLambdaExpression(call, 3);
                    }
                    if (null != resultSelectorLinqExpression)
                    {
                        // selector maps (Key, Group) -> Result
                        // push bindings for key and group
                        DbExpressionBinding topLevelProjectBinding = topLevelProject.BindAs(parent.AliasGenerator.Next());
                        DbPropertyExpression keyExpression = topLevelProjectBinding.Variable.Property(KeyColumnName);
                        DbPropertyExpression groupExpression = topLevelProjectBinding.Variable.Property(GroupColumnName);
                        parent._bindingContext.PushBindingScope(new Binding(resultSelectorLinqExpression.Parameters[0], keyExpression));
                        parent._bindingContext.PushBindingScope(new Binding(resultSelectorLinqExpression.Parameters[1], groupExpression));

                        // translate selector
                        CqtExpression resultSelector = parent.TranslateExpression(
                            resultSelectorLinqExpression.Body);
                        result = topLevelProjectBinding.Project(resultSelector);

                        parent._bindingContext.PopBindingScope();
                        parent._bindingContext.PopBindingScope();
                    }
                    return result;
                }

                // GroupBy is dispatched through the overload that receives the SequenceMethod.
                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    Debug.Fail("unreachable code");
                    return null;
                }
            }

            /// <summary>
            /// Translates GroupJoin into a projection over (outer element, correlated filtered inner collection).
            /// </summary>
            private sealed class GroupJoinTranslator : SequenceMethodTranslator
            {
                internal GroupJoinTranslator()
                    : base(SequenceMethod.GroupJoin)
                {
                }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    // o.GroupJoin(i, ok => outerKeySelector, ik => innerKeySelector, (o, i) => projection)
                    //      -->
                    // SELECT projection(o, i)
                    // FROM (
                    //      SELECT o, (SELECT i FROM i WHERE o.outerKeySelector = i.innerKeySelector) as i
                    //      FROM o)

                    // translate inputs
                    CqtExpression outer = parent.TranslateSet(call.Arguments[0]);
                    CqtExpression inner = parent.TranslateSet(call.Arguments[1]);

                    // translate key selectors
                    DbExpressionBinding outerBinding;
                    DbExpressionBinding innerBinding;
                    LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2);
                    LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3);
                    CqtExpression outerSelector = parent.TranslateLambda(
                        outerLambda, outer, out outerBinding);
                    CqtExpression innerSelector = parent.TranslateLambda(
                        innerLambda, inner, out innerBinding);

                    // create innermost SELECT i FROM i WHERE ...
                    if (!TypeSemantics.IsEqualComparable(outerSelector.ResultType) ||
                        !TypeSemantics.IsEqualComparable(innerSelector.ResultType))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
                    }
                    CqtExpression nestedCollection = parent.Filter(innerBinding,
                        parent.CreateEqualsExpression(outerSelector, innerSelector, EqualsPattern.PositiveNullEqualityNonComposable, outerLambda.Body.Type, innerLambda.Body.Type));

                    // create "join" SELECT o, (nestedCollection)
                    const string outerColumn = "o";
                    const string innerColumn = "i";
                    List<KeyValuePair<string, CqtExpression>> recordColumns = new List<KeyValuePair<string, CqtExpression>>(2);
                    recordColumns.Add(new KeyValuePair<string, CqtExpression>(outerColumn, outerBinding.Variable));
                    recordColumns.Add(new KeyValuePair<string, CqtExpression>(innerColumn, nestedCollection));
                    CqtExpression joinProjection = DbExpressionBuilder.NewRow(recordColumns);
                    CqtExpression joinProject = outerBinding.Project(joinProjection);
                    DbExpressionBinding joinProjectBinding = joinProject.BindAs(parent.AliasGenerator.Next());

                    // create property expressions for the outer and inner terms to bind to the parameters to the
                    // group join selector
                    CqtExpression outerProperty = joinProjectBinding.Variable.Property(outerColumn);
                    CqtExpression innerProperty = joinProjectBinding.Variable.Property(innerColumn);

                    // push the inner and the outer terms into the binding scope
                    LambdaExpression linqSelector = parent.GetLambdaExpression(call, 4);
                    parent._bindingContext.PushBindingScope(new Binding(linqSelector.Parameters[0], outerProperty));
                    parent._bindingContext.PushBindingScope(new Binding(linqSelector.Parameters[1], innerProperty));

                    // translate the selector
                    CqtExpression selectorProject = parent.TranslateExpression(linqSelector.Body);

                    // pop the binding scope
                    parent._bindingContext.PopBindingScope();
                    parent._bindingContext.PopBindingScope();

                    // create the selector projection
                    CqtExpression selector = joinProjectBinding.Project(selectorProject);

                    selector = CollapseTrivialRenamingProjection(selector);

                    return selector;
                }

                private CqtExpression CollapseTrivialRenamingProjection(CqtExpression cqtExpression)
                {
                    // Detect "select inner.x as m, inner.y as n
                    //         from (select ... as x, ... as y from ...) as inner"
                    // and convert to "select ... as m, ... as n from ..."
2921 2922 if (cqtExpression.ExpressionKind != DbExpressionKind.Project) 2923 { 2924 return cqtExpression; 2925 } 2926 var project = (DbProjectExpression)cqtExpression; 2927 2928 if (project.Projection.ExpressionKind != DbExpressionKind.NewInstance || 2929 project.Projection.ResultType.EdmType.BuiltInTypeKind != BuiltInTypeKind.RowType) 2930 { 2931 return cqtExpression; 2932 } 2933 var projection = (DbNewInstanceExpression)project.Projection; 2934 var outerRowType = (RowType)projection.ResultType.EdmType; 2935 2936 var renames = new List<Tuple<EdmProperty, string>>(); 2937 for (int i = 0; i < projection.Arguments.Count; ++i) 2938 { 2939 if (projection.Arguments[i].ExpressionKind != DbExpressionKind.Property) 2940 { 2941 return cqtExpression; 2942 } 2943 var rename = (DbPropertyExpression)projection.Arguments[i]; 2944 2945 if (rename.Instance != project.Input.Variable) 2946 { 2947 return cqtExpression; 2948 } 2949 renames.Add(Tuple.Create((EdmProperty)rename.Property, outerRowType.Properties[i].Name)); 2950 } 2951 2952 if (project.Input.Expression.ExpressionKind != DbExpressionKind.Project) 2953 { 2954 return cqtExpression; 2955 } 2956 var innerProject = (DbProjectExpression)project.Input.Expression; 2957 2958 if (innerProject.Projection.ExpressionKind != DbExpressionKind.NewInstance || 2959 innerProject.Projection.ResultType.EdmType.BuiltInTypeKind != BuiltInTypeKind.RowType) 2960 { 2961 return cqtExpression; 2962 } 2963 var innerProjection = (DbNewInstanceExpression)innerProject.Projection; 2964 var innerRowType = (RowType)innerProjection.ResultType.EdmType; 2965 2966 var newProjectionArguments = new List<CqtExpression>(); 2967 foreach (var rename in renames) 2968 { 2969 var innerPropertyIndex = innerRowType.Properties.IndexOf(rename.Item1); 2970 newProjectionArguments.Add(innerProjection.Arguments[innerPropertyIndex]); 2971 } 2972 2973 var newProjection = projection.ResultType.New(newProjectionArguments); 2974 return innerProject.Input.Project(newProjection); 2975 
} 2976 } 2977 private abstract class OrderByTranslatorBase : OneLambdaTranslator 2978 { 2979 private readonly bool _ascending; OrderByTranslatorBase(bool ascending, params SequenceMethod[] methods)2980 protected OrderByTranslatorBase(bool ascending, params SequenceMethod[] methods) 2981 : base(methods) 2982 { 2983 _ascending = ascending; 2984 } TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)2985 protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) 2986 { 2987 List<DbSortClause> keys = new List<DbSortClause>(1); 2988 DbSortClause sortSpec = (_ascending ? lambda.ToSortClause() : lambda.ToSortClauseDescending()); 2989 keys.Add(sortSpec); 2990 DbSortExpression sort = parent.Sort(sourceBinding, keys); 2991 return sort; 2992 } 2993 } 2994 private sealed class OrderByTranslator : OrderByTranslatorBase 2995 { OrderByTranslator()2996 internal OrderByTranslator() : base(true, SequenceMethod.OrderBy) { } 2997 } 2998 private sealed class OrderByDescendingTranslator : OrderByTranslatorBase 2999 { OrderByDescendingTranslator()3000 internal OrderByDescendingTranslator() : base(false, SequenceMethod.OrderByDescending) { } 3001 } 3002 // Note: because we need to "push-down" the expression binding for ThenBy, this class 3003 // does not inherit from OneLambdaTranslator, although it is similar. 
private abstract class ThenByTranslatorBase : SequenceMethodTranslator
{
    // true for ascending order, false for descending
    private readonly bool _ascending;

    protected ThenByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <summary>
    /// Appends a secondary sort key to an existing sort: source.ThenBy(x => key).
    /// </summary>
    /// <param name="parent">Converter used to translate the source and the key selector.</param>
    /// <param name="call">The ThenBy/ThenByDescending call (source, key selector).</param>
    /// <returns>A new sort over the original sort input with the existing keys followed by the new key.</returns>
    /// <remarks>
    /// Throws (via EntityUtil.InvalidOperation) when the source does not translate to a sort,
    /// i.e. ThenBy does not follow an OrderBy.
    /// </remarks>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(2 == call.Arguments.Count);
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);
        if (DbExpressionKind.Sort != source.ExpressionKind)
        {
            throw EntityUtil.InvalidOperation(System.Data.Entity.Strings.ELinq_ThenByDoesNotFollowOrderBy);
        }
        DbSortExpression sortExpression = (DbSortExpression)source;

        // reuse the input binding of the existing sort so the new key sees the same element
        DbExpressionBinding binding = sortExpression.Input;

        // get information on new sort term
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        ParameterExpression parameter = lambdaExpression.Parameters[0];

        // push-down the binding scope information and translate the new sort key
        parent._bindingContext.PushBindingScope(new Binding(parameter, binding.Variable));
        CqtExpression lambda;
        try
        {
            lambda = parent.TranslateExpression(lambdaExpression.Body);
        }
        finally
        {
            // restore the binding context even when translation throws
            parent._bindingContext.PopBindingScope();
        }

        // existing sort keys first, then the new one
        List<DbSortClause> keys = new List<DbSortClause>(sortExpression.SortOrder);
        keys.Add(new DbSortClause(lambda, _ascending, null));
        return parent.Sort(binding, keys);
    }
}

private sealed class ThenByTranslator : ThenByTranslatorBase
{
    internal ThenByTranslator() : base(true, SequenceMethod.ThenBy) { }
}

private sealed class ThenByDescendingTranslator : ThenByTranslatorBase
{
    internal ThenByDescendingTranslator() : base(false, SequenceMethod.ThenByDescending) { }
}
#endregion
}
}
}