1 #region MIT license 2 // 3 // MIT license 4 // 5 // Copyright (c) 2007-2008 Jiri Moudry, Pascal Craponne 6 // 7 // Permission is hereby granted, free of charge, to any person obtaining a copy 8 // of this software and associated documentation files (the "Software"), to deal 9 // in the Software without restriction, including without limitation the rights 10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 // copies of the Software, and to permit persons to whom the Software is 12 // furnished to do so, subject to the following conditions: 13 // 14 // The above copyright notice and this permission notice shall be included in 15 // all copies or substantial portions of the Software. 16 // 17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 // THE SOFTWARE. 24 // 25 #endregion 26 using System; 27 using System.CodeDom; 28 using System.CodeDom.Compiler; 29 using System.Collections.Generic; 30 using System.ComponentModel; 31 using System.Data; 32 using System.Data.Linq.Mapping; 33 using System.IO; 34 using System.Linq; 35 using System.Reflection; 36 using System.Text; 37 using System.Text.RegularExpressions; 38 39 using Microsoft.CSharp; 40 using Microsoft.VisualBasic; 41 42 using DbLinq.Schema.Dbml; 43 using DbLinq.Schema.Dbml.Adapter; 44 using DbLinq.Util; 45 46 namespace DbMetal.Generator 47 { 48 #if !MONO_STRICT 49 public 50 #endif 51 class CodeDomGenerator : ICodeGenerator 52 { 53 CodeDomProvider Provider { get; set; } 54 55 // Provided only for Processor.EnumerateCodeGenerators(). DO NOT USE. 
/// <summary>
/// Creates a generator without a <see cref="CodeDomProvider"/>. <see cref="Provider"/> stays null,
/// so <see cref="Write"/> cannot be used on such an instance.
/// </summary>
public CodeDomGenerator()
{
}

/// <summary>
/// Creates a generator that renders the CodeDOM model through <paramref name="provider"/>.
/// </summary>
/// <param name="provider">CodeDOM provider for the target language.</param>
public CodeDomGenerator(CodeDomProvider provider)
{
    this.Provider = provider;
}

/// <summary>Always returns the wildcard "*".</summary>
public string LanguageCode {
    get { return "*"; }
}

/// <summary>Always returns the wildcard "*".</summary>
public string Extension {
    get { return "*"; }
}

/// <summary>
/// Creates a generator for the language that <see cref="CodeDomProvider.GetLanguageFromExtension"/>
/// associates with <paramref name="extension"/>.
/// </summary>
public static CodeDomGenerator CreateFromFileExtension(string extension)
{
    return CreateFromLanguage(CodeDomProvider.GetLanguageFromExtension(extension));
}

/// <summary>
/// Creates a generator backed by <see cref="CodeDomProvider.CreateProvider(string)"/> for <paramref name="language"/>.
/// </summary>
public static CodeDomGenerator CreateFromLanguage(string language)
{
    return new CodeDomGenerator(CodeDomProvider.CreateProvider(language));
}

/// <summary>
/// Builds the CodeDOM model for <paramref name="dbSchema"/> and writes it to
/// <paramref name="textWriter"/> with C-style bracing and tab indentation.
/// </summary>
/// <param name="textWriter">Destination of the generated source.</param>
/// <param name="dbSchema">Database schema to generate code for.</param>
/// <param name="context">Generation settings; stored in <see cref="Context"/> for the helpers.</param>
public void Write(TextWriter textWriter, Database dbSchema, GenerationContext context)
{
    Context = context;
    var options = new CodeGeneratorOptions() {
        BracingStyle = "C",
        IndentString = "\t",
    };
    // CodeDomProvider.GenerateCodeFromNamespace replaces the obsolete CreateGenerator() route.
    Provider.GenerateCodeFromNamespace(GenerateCodeDomModel(dbSchema), textWriter, options);
}

/// <summary>
/// Writes "&lt;program name&gt;: warning: &lt;formatted message&gt;" to standard error.
/// </summary>
static void Warning(string format, params object[] args)
{
    Console.Error.Write(Path.GetFileName(Environment.GetCommandLineArgs()[0]));
    Console.Error.Write(": warning: ");
    Console.Error.WriteLine(format, args);
}

/// <summary>
/// Declares a partial method named <paramref name="methodName"/>.
/// CodeDOM has no partial-method construct, so for C# and VB the declaration is emitted as a
/// raw snippet; for any other provider an ordinary empty method is generated instead.
/// </summary>
private CodeTypeMember CreatePartialMethod(string methodName, params CodeParameterDeclarationExpression[] parameters)
{
    string prototype = null;
    if (Provider is CSharpCodeProvider)
    {
        prototype =
            "\t\tpartial void {0}({1});" + Environment.NewLine +
            "\t\t";
    }
    else if (Provider is VBCodeProvider)
    {
        prototype =
            "\t\tPartial Private Sub {0}({1})" + Environment.NewLine +
            "\t\tEnd Sub" + Environment.NewLine +
            "\t\t";
    }

    if (prototype == null)
    {
        var method = new CodeMemberMethod() {
            Name = methodName,
        };
        method.Parameters.AddRange(parameters);
        return method;
    }

    // Render each parameter with the provider so the list uses the target language's syntax.
    using (var parameterList = new StringWriter())
    {
        for (int i = 0; i < parameters.Length; ++i)
        {
            if (i > 0)
                parameterList.Write(", ");
            Provider.GenerateCodeFromExpression(parameters[i], parameterList, null);
        }
        return new CodeSnippetTypeMember(string.Format(prototype, methodName, parameterList.ToString()));
    }
}

// Shared 'this' expression used as the target of generated member references and calls.
CodeThisReferenceExpression thisReference = new CodeThisReferenceExpression();

/// <summary>Generation settings for the current <see cref="Write"/> call.</summary>
protected GenerationContext Context { get; set; }

/// <summary>
/// Builds the namespace holding the data context class and one entity class per table.
/// The namespace comes from the command-line parameters, or else from the schema.
/// </summary>
protected virtual CodeNamespace GenerateCodeDomModel(Database database)
{
    CodeNamespace _namespace = new CodeNamespace(Context.Parameters.Namespace ?? database.ContextNamespace);

    _namespace.Imports.Add(new CodeNamespaceImport("System"));
    _namespace.Imports.Add(new CodeNamespaceImport("System.ComponentModel"));
#if MONO_STRICT
    _namespace.Imports.Add(new CodeNamespaceImport("System.Data"));
    _namespace.Imports.Add(new CodeNamespaceImport("System.Data.Linq"));
    _namespace.Imports.Add(new CodeNamespaceImport("System.Data.Linq.Mapping"));
#else
    AddConditionalImports(_namespace.Imports,
        "System.Data",
        "MONO_STRICT",
        new[] { "System.Data.Linq" },
        new[] { "DbLinq.Data.Linq", "DbLinq.Vendor" },
        "System.Data.Linq.Mapping");
#endif
    _namespace.Imports.Add(new CodeNamespaceImport("System.Diagnostics"));

    // The "u" format appends a literal 'Z' (UTC designator) without converting the value,
    // so the timestamp has to be taken in UTC to be truthful.
    var time = Context.Parameters.GenerateTimestamps ? DateTime.UtcNow.ToString("u") : "[TIMESTAMP]";
    var header = new CodeCommentStatement(GenerateCommentBanner(database, time));
    _namespace.Comments.Add(header);

    _namespace.Types.Add(GenerateContextClass(database));
#if !MONO_STRICT
    _namespace.Types.Add(GenerateMonoStrictContextConstructors(database));
    _namespace.Types.Add(GenerateNotMonoStrictContextConstructors(database));
#endif

    foreach (Table table in database.Tables)
        _namespace.Types.Add(GenerateTableClass(table, database));
    return _namespace;
}

/// <summary>
/// Adds <paramref name="firstImport"/>, then an #if/#else block that selects between
/// <paramref name="importsIfTrue"/> and <paramref name="importsIfFalse"/> on
/// <paramref name="conditional"/>, then <paramref name="lastImport"/>.
/// For providers other than C#/VB, only the "true" imports are emitted (no conditional).
/// </summary>
void AddConditionalImports(CodeNamespaceImportCollection imports,
    string firstImport,
    string conditional,
    string[] importsIfTrue,
    string[] importsIfFalse,
    string lastImport)
{
    if (Provider is CSharpCodeProvider)
    {
        // HACK HACK HACK
        // Would be better if CodeDom actually supported conditional compilation constructs...
        // This is predicated upon CSharpCodeGenerator.GenerateNamespaceImport() being implemented as:
        //      output.Write ("using ");
        //      output.Write (GetSafeName (import.Namespace));
        //      output.WriteLine (';');
        // Thus, with a "crafty" namespace string, we can stuff arbitrary text in there...

        var block = new StringBuilder();
        // No 'using', as GenerateNamespaceImport() writes it.
        block.Append(firstImport).Append(";").Append(Environment.NewLine);
        block.Append("#if ").Append(conditional).Append(Environment.NewLine);
        foreach (var ns in importsIfTrue)
            block.Append("\tusing ").Append(ns).Append(";").Append(Environment.NewLine);
        block.Append("#else   // ").Append(conditional).Append(Environment.NewLine);
        foreach (var ns in importsIfFalse)
            block.Append("\tusing ").Append(ns).Append(";").Append(Environment.NewLine);
        block.Append("#endif  // ").Append(conditional).Append(Environment.NewLine);
        block.Append("\tusing ").Append(lastImport);
        // No ';', as GenerateNamespaceImport() writes it.

        imports.Add(new CodeNamespaceImport(block.ToString()));
    }
    else if (Provider is VBCodeProvider)
    {
        // HACK HACK HACK
        // Would be better if CodeDom actually supported conditional compilation constructs...
        // This is predicated upon VBCodeGenerator.GenerateNamespaceImport() being implemented as:
        //      output.Write ("Imports ");
        //      output.Write (import.Namespace);
        //      output.WriteLine ();
        // Thus, with a "crafty" namespace string, we can stuff arbitrary text in there...

        var block = new StringBuilder();
        // No 'Imports', as GenerateNamespaceImport() writes it.
        block.Append(firstImport).Append(Environment.NewLine);
        block.Append("#If ").Append(conditional).Append(" Then").Append(Environment.NewLine);
        foreach (var ns in importsIfTrue)
            block.Append("Imports ").Append(ns).Append(Environment.NewLine);
        block.Append("#Else     ' ").Append(conditional).Append(Environment.NewLine);
        foreach (var ns in importsIfFalse)
            block.Append("Imports ").Append(ns).Append(Environment.NewLine);
        block.Append("#End If   ' ").Append(conditional).Append(Environment.NewLine);
        block.Append("Imports ").Append(lastImport);
        // No newline, as GenerateNamespaceImport() writes it.

        imports.Add(new CodeNamespaceImport(block.ToString()));
    }
    else
    {
        // Default to using the DbLinq imports
        imports.Add(new CodeNamespaceImport(firstImport));
        foreach (var ns in importsIfTrue)
            imports.Add(new CodeNamespaceImport(ns));
        imports.Add(new CodeNamespaceImport(lastImport));
    }
}

/// <summary>
/// Returns the ASCII-art banner placed at the top of the generated file, followed by the
/// source database name and <paramref name="time"/>.
/// </summary>
private string GenerateCommentBanner(Database database, string time)
{
    var result = new StringBuilder();

    // http://www.network-science.de/ascii/
    // http://www.network-science.de/ascii/ascii.php?TEXT=MetalSequel&x=14&y=14&FONT=_all+fonts+with+your+text_&RICH=no&FORM=left&STRE=no&WIDT=80
    result.Append(
@"
 ____  _     __  __      _        _
|  _ \| |__ |  \/  | ___| |_ __ _| |
| | | | '_ \| |\/| |/ _ \ __/ _` | |
| |_| | |_) | |  | |  __/ || (_| | |
|____/|_.__/|_|  |_|\___|\__\__,_|_|

");
    result.AppendLine(String.Format(" Auto-generated from {0} on {1}.", database.Name, time));
    result.AppendLine(" Please visit http://code.google.com/p/dblinq2007/ for more information.");

    return result.ToString();
}

/// <summary>
/// Builds the partial data context class with:
/// <list type="bullet">
/// <item>the OnCreated() extensibility hook;</item>
/// <item>the standard constructors;</item>
/// <item>one Table&lt;T&gt; property per table;</item>
/// <item>one method per stored function.</item>
/// </list>
/// </summary>
protected virtual CodeTypeDeclaration GenerateContextClass(Database database)
{
    var _class = new CodeTypeDeclaration() {
        IsClass = true,
        IsPartial = true,
        Name = database.Class,
        TypeAttributes = TypeAttributes.Public
    };

    _class.BaseTypes.Add(GetContextBaseType(database.BaseType));

    var onCreated = CreatePartialMethod("OnCreated");
    onCreated.StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Extensibility Method Declarations"));
    onCreated.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
    _class.Members.Add(onCreated);

    // Implement Constructor
    GenerateContextConstructors(_class, database);

    foreach (Table table in database.Tables)
    {
        var tableType = new CodeTypeReference(table.Type.Name);
        var property = new CodeMemberProperty() {
            Attributes = MemberAttributes.Public | MemberAttributes.Final,
            Name = table.Member,
            Type = new CodeTypeReference("Table", tableType),
        };
        // get { return this.GetTable<TEntity>(); }
        property.GetStatements.Add(
            new CodeMethodReturnStatement(
                new CodeMethodInvokeExpression(
                    new CodeMethodReferenceExpression(thisReference, "GetTable", tableType))));
        _class.Members.Add(property);
    }

    foreach (var function in database.Functions)
    {
        GenerateContextFunction(_class, function);
    }

    return _class;
}

/// <summary>
/// Returns the simple name of the context base type.
/// Falls back to "DataContext" when <paramref name="type"/> is empty or cannot be loaded.
/// </summary>
static string GetContextBaseType(string type)
{
    string baseType = "DataContext";

    if (!string.IsNullOrEmpty(type))
    {
        var t = TypeLoader.Load(type);
        if (t != null)
            baseType = t.Name;
    }

    return baseType;
}

/// <summary>
/// Builds a public constructor that forwards every one of <paramref name="parameters"/>, in
/// order, to the base constructor and then calls OnCreated().
/// </summary>
CodeConstructor CreateForwardingContextConstructor(params CodeParameterDeclarationExpression[] parameters)
{
    var constructor = new CodeConstructor() { Attributes = MemberAttributes.Public };
    constructor.Parameters.AddRange(parameters);
    foreach (var parameter in parameters)
        constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression(parameter.Name));
    constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
    return constructor;
}

/// <summary>
/// Adds the context constructors shared by both build flavours; the IDbConnection-only
/// constructor is added here only when the generator itself is built with MONO_STRICT.
/// </summary>
void GenerateContextConstructors(CodeTypeDeclaration contextType, Database database)
{
    // .ctor(string connectionString);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression(typeof(string), "connectionString")));

#if MONO_STRICT
    // .ctor(IDbConnection connection);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression("IDbConnection", "connection")));
#endif

    // .ctor(string connection, MappingSource mappingSource);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression(typeof(string), "connection"),
        new CodeParameterDeclarationExpression("MappingSource", "mappingSource")));

    // .ctor(IDbConnection connection, MappingSource mappingSource);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression("IDbConnection", "connection"),
        new CodeParameterDeclarationExpression("MappingSource", "mappingSource")));
}

/// <summary>
/// Builds a partial context class holding the .ctor(IDbConnection) used when the generated
/// code is compiled with MONO_STRICT; it is wrapped in the "#if MONO_STRICT ... #else" part
/// of the conditional block (see <see cref="AddConditionalIfElseBlocks"/>).
/// </summary>
CodeTypeDeclaration GenerateMonoStrictContextConstructors(Database database)
{
    var contextType = new CodeTypeDeclaration()
    {
        IsClass = true,
        IsPartial = true,
        Name = database.Class,
        TypeAttributes = TypeAttributes.Public
    };
    AddConditionalIfElseBlocks(contextType, "MONO_STRICT");

    // .ctor(IDbConnection connection);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression("IDbConnection", "connection")));

    return contextType;
}

/// <summary>
/// HACK: CodeDOM has no conditional-compilation support, so the #if/#else lines are
/// smuggled into the names of #region directives around <paramref name="member"/>.
/// Does nothing for providers other than C# and VB.
/// </summary>
void AddConditionalIfElseBlocks(CodeTypeMember member, string condition)
{
    string startIf = null, elseIf = null;
    if (Provider is CSharpCodeProvider)
    {
        startIf = string.Format("Start {0}{1}#if {0}{1}", condition, Environment.NewLine);
        elseIf = string.Format("End {0}{1}\t#endregion{1}#else     // {0}", condition, Environment.NewLine);
    }
    else if (Provider is VBCodeProvider)
    {
        startIf = string.Format("Start {0}\"{1}#If {0} Then{1}    '", condition, Environment.NewLine);
        elseIf = string.Format("End {0}\"{1}\t#End Region{1}#Else     ' {0}", condition, Environment.NewLine);
    }
    if (startIf != null && elseIf != null)
    {
        member.StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, startIf));
        member.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, elseIf));
    }
}

/// <summary>
/// Counterpart of <see cref="AddConditionalIfElseBlocks"/>: closes the conditional block
/// after <paramref name="member"/> using the same #region-name hack.
/// Does nothing for providers other than C# and VB.
/// </summary>
void AddConditionalEndifBlocks(CodeTypeMember member, string condition)
{
    string endIf = null;
    if (Provider is CSharpCodeProvider)
    {
        endIf = string.Format("End Not {0}{1}\t#endregion{1}#endif     // {0}", condition, Environment.NewLine);
    }
    else if (Provider is VBCodeProvider)
    {
        endIf = string.Format("End Not {0}\"{1}\t#End Region{1}#End If     ' {0}", condition, Environment.NewLine);
    }
    if (endIf != null)
    {
        member.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, endIf));
        member.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
    }
}

/// <summary>
/// Builds the partial context class used when the generated code is compiled WITHOUT
/// MONO_STRICT. Its constructors take a DbLinq IVendor; the IDbConnection-only constructor
/// instantiates the vendor type of the schema loader used for this generation.
/// </summary>
CodeTypeDeclaration GenerateNotMonoStrictContextConstructors(Database database)
{
    var contextType = new CodeTypeDeclaration() {
        IsClass = true,
        IsPartial = true,
        Name = database.Class,
        TypeAttributes = TypeAttributes.Public
    };
    AddConditionalEndifBlocks(contextType, "MONO_STRICT");

    // .ctor(IDbConnection connection) : base(connection, new <Vendor>())
    var constructor = CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression("IDbConnection", "connection"));
    constructor.BaseConstructorArgs.Add(new CodeObjectCreateExpression(Context.SchemaLoader.Vendor.GetType()));
    contextType.Members.Add(constructor);

    // .ctor(IDbConnection connection, IVendor sqlDialect);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression("IDbConnection", "connection"),
        new CodeParameterDeclarationExpression("IVendor", "sqlDialect")));

    // .ctor(IDbConnection connection, MappingSource mappingSource, IVendor sqlDialect);
    contextType.Members.Add(CreateForwardingContextConstructor(
        new CodeParameterDeclarationExpression("IDbConnection", "connection"),
        new CodeParameterDeclarationExpression("MappingSource", "mappingSource"),
        new CodeParameterDeclarationExpression("IVendor", "sqlDialect")));

    return contextType;
}

/// <summary>
/// Adds a [Function]-attributed context method that calls <paramref name="function"/>.
/// The method calls ExecuteMethodCall, copies out/inout parameter values back from the result,
/// and returns the cast ReturnValue when the function has a return type.
/// Functions without a name are skipped with a warning.
/// </summary>
void GenerateContextFunction(CodeTypeDeclaration contextType, Function function)
{
    if (function == null || string.IsNullOrEmpty(function.Name))
    {
        Warning("L33 Invalid stored procedure object: missing name.");
        return;
    }

    var methodRetType = GetFunctionReturnType(function);
    var method = new CodeMemberMethod() {
        Attributes = ToMemberAttributes(function),
        Name = function.Method ?? function.Name,
        ReturnType = methodRetType,
        CustomAttributes = {
            new CodeAttributeDeclaration("Function",
                new CodeAttributeArgument("Name", new CodePrimitiveExpression(function.Name)),
                new CodeAttributeArgument("IsComposable", new CodePrimitiveExpression(function.IsComposable))),
        },
    };
    method.Parameters.AddRange(function.Parameters.Select(x => GetFunctionParameterType(x)).ToArray());
    if (function.Return != null && !string.IsNullOrEmpty(function.Return.DbType))
        method.ReturnTypeCustomAttributes.Add(
            new CodeAttributeDeclaration("Parameter",
                new CodeAttributeArgument("DbType", new CodePrimitiveExpression(function.Return.DbType))));

    contextType.Members.Add(method);

    // 'out' parameters must be definitely assigned before they are passed to ExecuteMethodCall.
    // InOut parameters are generated as 'ref' and already carry the caller's value, so they
    // must not be overwritten. The declared parameter type (GetFunctionType) is used so that
    // value types match their Nullable<T> declaration.
    foreach (var p in function.Parameters)
    {
        if (p.Direction != DbLinq.Schema.Dbml.ParameterDirection.Out)
            continue;
        method.Statements.Add(
            new CodeAssignStatement(
                new CodeVariableReferenceExpression(p.Name),
                new CodeDefaultValueExpression(GetFunctionType(p.Type))));
    }

    // IExecuteResult result = this.ExecuteMethodCall(this, (MethodInfo) MethodBase.GetCurrentMethod(), p0, p1, ...);
    var executeMethodCallArgs = new List<CodeExpression>() {
        thisReference,
        new CodeCastExpression(
            new CodeTypeReference("System.Reflection.MethodInfo"),
            new CodeMethodInvokeExpression(
                new CodeMethodReferenceExpression(
                    new CodeTypeReferenceExpression("System.Reflection.MethodBase"), "GetCurrentMethod"))),
    };
    executeMethodCallArgs.AddRange(
        function.Parameters.Select(p => (CodeExpression) new CodeVariableReferenceExpression(p.Name)));
    method.Statements.Add(
        new CodeVariableDeclarationStatement(
            new CodeTypeReference("IExecuteResult"),
            "result",
            new CodeMethodInvokeExpression(
                new CodeMethodReferenceExpression(thisReference, "ExecuteMethodCall"),
                executeMethodCallArgs.ToArray())));

    // Copy output values back, casting to the declared (possibly Nullable<T>) type so a NULL
    // output value does not fail to unbox.
    for (int i = 0; i < function.Parameters.Count; ++i)
    {
        var p = function.Parameters[i];
        if (!p.DirectionOut)
            continue;
        method.Statements.Add(
            new CodeAssignStatement(
                new CodeVariableReferenceExpression(p.Name),
                new CodeCastExpression(
                    GetFunctionType(p.Type),
                    new CodeMethodInvokeExpression(
                        new CodeMethodReferenceExpression(
                            new CodeVariableReferenceExpression("result"),
                            "GetParameterValue"),
                        new CodePrimitiveExpression(i)))));
    }

    if (methodRetType != null)
    {
        method.Statements.Add(
            new CodeMethodReturnStatement(
                new CodeCastExpression(
                    method.ReturnType,
                    new CodePropertyReferenceExpression(
                        new CodeVariableReferenceExpression("result"),
                        "ReturnValue"))));
    }
}

/// <summary>
/// Returns the generated method's return type: the mapped <c>function.Return.Type</c>, or
/// DataSet when the result shape is unknown (see below), or null (void) otherwise.
/// </summary>
CodeTypeReference GetFunctionReturnType(Function function)
{
    CodeTypeReference type = null;
    if (function.Return != null)
    {
        type = GetFunctionType(function.Return.Type);
    }

    bool isDataShapeUnknown = function.ElementType == null
                              && function.BodyContainsSelectStatement
                              && !function.IsComposable;
    if (isDataShapeUnknown)
    {
        //if we don't know the shape of results, and the proc body contains some selects,
        //we have no choice but to return an untyped DataSet.
        //
        //TODO: either parse proc body like microsoft,
        //or create a little GUI tool which would call the proc with test values, to determine result shape.
        type = new CodeTypeReference(typeof(DataSet));
    }
    return type;
}

/// <summary>
/// Maps a type name to a <see cref="CodeTypeReference"/>. Value types that
/// <see cref="System.Type.GetType(string)"/> can load become Nullable&lt;T&gt;; names it cannot
/// load are referenced verbatim.
/// </summary>
static CodeTypeReference GetFunctionType(string type)
{
    var t = System.Type.GetType(type);
    if (t == null)
        return new CodeTypeReference(type);
    if (t.IsValueType)
        return new CodeTypeReference(typeof(Nullable<>)) {
            TypeArguments = {
                new CodeTypeReference(t),
            },
        };
    return new CodeTypeReference(t);
}

/// <summary>
/// Builds the [Parameter]-attributed method parameter for a stored-function parameter.
/// Direction mapping:
/// <list type="bullet">
/// <item>In → in;</item>
/// <item>Out → out;</item>
/// <item>InOut → ref.</item>
/// </list>
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">Unknown parameter direction.</exception>
CodeParameterDeclarationExpression GetFunctionParameterType(Parameter parameter)
{
    var p = new CodeParameterDeclarationExpression(GetFunctionType(parameter.Type), parameter.Name) {
        CustomAttributes = {
            new CodeAttributeDeclaration("Parameter",
                new CodeAttributeArgument("Name", new CodePrimitiveExpression(parameter.Name)),
                new CodeAttributeArgument("DbType", new CodePrimitiveExpression(parameter.DbType))),
        },
    };
    switch (parameter.Direction)
    {
        case DbLinq.Schema.Dbml.ParameterDirection.In:
            p.Direction = FieldDirection.In;
            break;
        case DbLinq.Schema.Dbml.ParameterDirection.Out:
            p.Direction = FieldDirection.Out;
            break;
        case DbLinq.Schema.Dbml.ParameterDirection.InOut:
            // FieldDirection is not a flags enum (In == 0), so In | Out would just be Out;
            // an input/output parameter is a 'ref' parameter.
            p.Direction = FieldDirection.Ref;
            break;
        default:
            throw new ArgumentOutOfRangeException();
    }
    return p;
}

/// <summary>
/// Builds the partial [Table] entity class for <paramref name="table"/>. It contains:
/// <list type="bullet">
/// <item>custom enum types;</item>
/// <item>change notification, when the table has a primary key;</item>
/// <item>the constructor;</item>
/// <item>optional Equals/GetHashCode;</item>
/// <item>extensibility hooks;</item>
/// <item>one field/property pair per column;</item>
/// <item>association members.</item>
/// </list>
/// </summary>
protected CodeTypeDeclaration GenerateTableClass(Table table, Database database)
{
    var _class = new CodeTypeDeclaration() {
        IsClass = true,
        IsPartial = true,
        Name = table.Type.Name,
        TypeAttributes = TypeAttributes.Public,
        CustomAttributes = {
            new CodeAttributeDeclaration("Table",
                new CodeAttributeArgument("Name", new CodePrimitiveExpression(table.Name))),
        },
    };

    WriteCustomTypes(_class, table);

    // Change-notification members are only generated for tables with a primary key.
    var havePrimaryKeys = table.Type.Columns.Any(c => c.IsPrimaryKey);
    if (havePrimaryKeys)
    {
        GenerateINotifyPropertyChanging(_class);
        GenerateINotifyPropertyChanged(_class);
    }

    // Implement Constructor
    _class.Members.Add(CreateEntityConstructor(table));

    if (Context.Parameters.GenerateEqualsHash)
    {
        GenerateEntityGetHashCodeAndEquals(_class, table);
    }

    GenerateExtensibilityDeclarations(_class, table);

    // todo: add these when the actually get called
    //partial void OnLoaded();
    //partial void OnValidate(System.Data.Linq.ChangeAction action);

    // columns
    foreach (Column column in table.Type.Columns)
    {
        AddColumnMembers(_class, table, column, havePrimaryKeys);
    }

    GenerateEntityChildren(_class, table, database);
    GenerateEntityChildrenAttachment(_class, table, database);
    GenerateEntityParents(_class, table, database);

    return _class;
}

/// <summary>
/// Builds the entity's public parameterless constructor. It initializes one EntitySet per
/// child association, wired to the &lt;Member&gt;_Attach / &lt;Member&gt;_Detach handlers,
/// and then calls OnCreated().
/// </summary>
CodeConstructor CreateEntityConstructor(Table table)
{
    var constructor = new CodeConstructor() { Attributes = MemberAttributes.Public };
    // children are EntitySet
    foreach (var child in GetClassChildren(table))
    {
        // if the association has a storage, we use it. Otherwise, we use the property name
        var entitySetMember = GetStorageFieldName(child);
        constructor.Statements.Add(
            new CodeAssignStatement(
                new CodeVariableReferenceExpression(entitySetMember),
                new CodeObjectCreateExpression(
                    new CodeTypeReference("EntitySet", new CodeTypeReference(child.Type)),
                    new CodeDelegateCreateExpression(
                        new CodeTypeReference("Action", new CodeTypeReference(child.Type)),
                        thisReference, child.Member + "_Attach"),
                    new CodeDelegateCreateExpression(
                        new CodeTypeReference("Action", new CodeTypeReference(child.Type)),
                        thisReference, child.Member + "_Detach"))));
    }
    constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
    return constructor;
}

/// <summary>
/// Adds the storage field and the [Column] property for <paramref name="column"/>.
/// The setter runs only when the value changes. It first throws if a related foreign-key
/// association is already loaded/assigned. It then calls On&lt;Member&gt;Changing(value),
/// assigns the field and calls On&lt;Member&gt;Changed(). SendPropertyChanging/Changed are
/// also called when <paramref name="havePrimaryKeys"/> is true.
/// </summary>
void AddColumnMembers(CodeTypeDeclaration entity, Table table, Column column, bool havePrimaryKeys)
{
    var relatedAssociations = from a in table.Type.Associations
                              where a.IsForeignKey && a.TheseKeys.Contains(column.Name)
                              select a;

    var type = ToCodeTypeReference(column);
    var columnMember = column.Member ?? column.Name;

    var field = new CodeMemberField(type, GetStorageFieldName(column));
    entity.Members.Add(field);
    var fieldReference = new CodeFieldReferenceExpression(new CodeThisReferenceExpression(), field.Name);

    var onChanging = GetChangingMethodName(columnMember);
    var onChanged = GetChangedMethodName(columnMember);

    var property = new CodeMemberProperty() {
        Type = type,
        Name = columnMember,
        Attributes = MemberAttributes.Public | MemberAttributes.Final,
    };
    property.CustomAttributes.Add(CreateColumnAttributeDeclaration(column));
    property.CustomAttributes.Add(new CodeAttributeDeclaration("DebuggerNonUserCode"));

    property.GetStatements.Add(new CodeMethodReturnStatement(fieldReference));

    var whenUpdating = new List<CodeStatement>(
        from assoc in relatedAssociations
        select (CodeStatement) new CodeConditionStatement(
            new CodePropertyReferenceExpression(
                new CodeVariableReferenceExpression(GetStorageFieldName(assoc)),
                "HasLoadedOrAssignedValue"),
            new CodeThrowExceptionStatement(
                new CodeObjectCreateExpression(typeof(System.Data.Linq.ForeignKeyReferenceAlreadyHasValueException)))));
    whenUpdating.Add(
        new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, onChanging, new CodePropertySetValueReferenceExpression())));
    if (havePrimaryKeys)
        whenUpdating.Add(
            new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "SendPropertyChanging")));
    whenUpdating.Add(
        new CodeAssignStatement(fieldReference, new CodePropertySetValueReferenceExpression()));
    if (havePrimaryKeys)
        whenUpdating.Add(
            new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "SendPropertyChanged", new CodePrimitiveExpression(property.Name))));
    whenUpdating.Add(
        new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, onChanged)));

    // This is needed for VB.NET generation:
    // int/string/etc. can use '<>' for comparison, but NOT arrays and other reference types;
    // those require the 'Is' operator, which is CodeBinaryOperatorType.IdentityEquality.
    // The VB IsNot operator is not exposed from CodeDom.
    // So reference and nullable types get '(field Is value) = false'; everything else gets
    // '(field <> value)'. TypeLoader.Load can return null (GetContextBaseType guards for it
    // too); in that case the value comparison is used instead of dereferencing null.
    var fieldType = TypeLoader.Load(column.Type);
    var fieldExpression = new CodeVariableReferenceExpression(field.Name);
    bool compareByIdentity = fieldType != null && (fieldType.IsClass || fieldType.IsNullable());
    CodeBinaryOperatorExpression condition = compareByIdentity
        ? ValuesAreNotEqual_Ref(fieldExpression, new CodePropertySetValueReferenceExpression())
        : ValuesAreNotEqual(fieldExpression, new CodePropertySetValueReferenceExpression());
    property.SetStatements.Add(new CodeConditionStatement(condition, whenUpdating.ToArray()));
    entity.Members.Add(property);
}

/// <summary>
/// Builds the [Column] attribute for <paramref name="column"/>.
/// Storage, Name and DbType are always emitted. IsPrimaryKey, IsDbGenerated and CanBeNull are
/// emitted only when they differ from a default ColumnAttribute. AutoSync is emitted when it
/// is not Default, and Expression when it is set.
/// </summary>
CodeAttributeDeclaration CreateColumnAttributeDeclaration(Column column)
{
    var defAttrValues = new ColumnAttribute();
    var args = new List<CodeAttributeArgument>() {
        new CodeAttributeArgument("Storage", new CodePrimitiveExpression(GetStorageFieldName(column))),
        new CodeAttributeArgument("Name", new CodePrimitiveExpression(column.Name)),
        new CodeAttributeArgument("DbType", new CodePrimitiveExpression(column.DbType)),
    };
    if (defAttrValues.IsPrimaryKey != column.IsPrimaryKey)
        args.Add(new CodeAttributeArgument("IsPrimaryKey", new CodePrimitiveExpression(column.IsPrimaryKey)));
    if (defAttrValues.IsDbGenerated != column.IsDbGenerated)
        args.Add(new CodeAttributeArgument("IsDbGenerated", new CodePrimitiveExpression(column.IsDbGenerated)));
    if (column.AutoSync != DbLinq.Schema.Dbml.AutoSync.Default)
        args.Add(new CodeAttributeArgument("AutoSync",
            new CodeFieldReferenceExpression(new CodeTypeReferenceExpression("AutoSync"), column.AutoSync.ToString())));
    if (defAttrValues.CanBeNull != column.CanBeNull)
        args.Add(new CodeAttributeArgument("CanBeNull", new CodePrimitiveExpression(column.CanBeNull)));
    if (column.Expression != null)
        args.Add(new CodeAttributeArgument("Expression", new CodePrimitiveExpression(column.Expression)));
    return new CodeAttributeDeclaration("Column", args.ToArray());
}

/// <summary>
/// Registers the enum-typed columns of <paramref name="table"/> in Context.ExtendedTypes and
/// assigns unique names to nameless ones. It then emits the enum declarations belonging to
/// this table as nested types of <paramref name="entity"/>.
/// </summary>
void WriteCustomTypes(CodeTypeDeclaration entity, Table table)
{
    // detect required custom types
    foreach (var column in table.Type.Columns)
    {
        if (column.ExtendedType is EnumType)
        {
            Context.ExtendedTypes[column] = new GenerationContext.ExtendedTypeAndName {
                Type = column.ExtendedType,
                Table = table
            };
        }
    }

    var customTypesNames = new List<string>();

    // create names and avoid conflicts
    foreach (var extendedTypePair in Context.ExtendedTypes)
    {
        if (extendedTypePair.Value.Table != table)
            continue;

        if (string.IsNullOrEmpty(extendedTypePair.Value.Type.Name))
        {
            string name = extendedTypePair.Key.Member + "Type";
            while (Context.ExtendedTypes.Values.Any(t => t.Type.Name == name))
            {
                // at 3rd loop, it will look ugly, however we will never go there
                name = extendedTypePair.Value.Table.Type.Name + name;
            }
            extendedTypePair.Value.Type.Name = name;
        }
        customTypesNames.Add(extendedTypePair.Value.Type.Name);
    }

    // write custom types
    if (customTypesNames.Count > 0)
    {
        var customTypes = new List<CodeTypeDeclaration>(customTypesNames.Count);

        foreach (var extendedTypePair in Context.ExtendedTypes)
        {
            if (extendedTypePair.Value.Table != table)
                continue;

            var enumValue = extendedTypePair.Value.Type as EnumType;
            if (enumValue == null)
                continue;

            var enumType = new CodeTypeDeclaration(enumValue.Name) {
                TypeAttributes = TypeAttributes.Public,
                IsEnum = true,
            };
            customTypes.Add(enumType);
            var orderedValues = from nv in enumValue orderby nv.Value select nv;
            // C# and VB number an enum member without an initializer as previous + 1,
            // starting at 0; only emit an explicit value where that rule would be wrong.
            int currentValue = 0;
            foreach (var nameValue in orderedValues)
            {
                var field = new CodeMemberField() {
                    Name = nameValue.Key,
                };
                enumType.Members.Add(field);
                if (nameValue.Value != currentValue)
                {
                    currentValue = nameValue.Value;
                    field.InitExpression = new CodePrimitiveExpression(nameValue.Value);
                }
                currentValue++;
            }
        }

        if (customTypes.Count == 0)
            return;
        customTypes.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start,
            string.Format("Custom type definitions for {0}", string.Join(", ", customTypesNames.ToArray()))));
        customTypes.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
        entity.Members.AddRange(customTypes.ToArray());
    }
}

/// <summary>
/// Adds, inside an "Extensibility Method Declarations" region, the OnCreated() partial
/// method plus On&lt;Column&gt;Changed() / On&lt;Column&gt;Changing(value) for every column.
/// </summary>
void GenerateExtensibilityDeclarations(CodeTypeDeclaration entity, Table table)
{
    var partialMethods = new[] { CreatePartialMethod("OnCreated") }
        .Concat(table.Type.Columns.Select(c => new[] { CreateChangedMethodDecl(c), CreateChangingMethodDecl(c) })
            .SelectMany(md => md)).ToArray();
    partialMethods.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Extensibility Method Declarations"));
    partialMethods.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
    entity.Members.AddRange(partialMethods);
}

/// <summary>Returns "On&lt;columnName&gt;Changed".</summary>
static string GetChangedMethodName(string columnName)
{
    return string.Format("On{0}Changed", columnName);
}

/// <summary>
/// Declares On&lt;Member&gt;Changed(). It uses the same member-name fallback
/// (Member, else Name) as the property setter that calls it.
/// </summary>
CodeTypeMember CreateChangedMethodDecl(Column column)
{
    return CreatePartialMethod(GetChangedMethodName(column.Member ?? column.Name));
}

/// <summary>Returns "On&lt;columnName&gt;Changing".</summary>
static string GetChangingMethodName(string columnName)
{
    return string.Format("On{0}Changing", columnName);
}

/// <summary>
/// Declares On&lt;Member&gt;Changing(value). It uses the same member-name fallback
/// (Member, else Name) as the property setter that calls it.
/// </summary>
CodeTypeMember CreateChangingMethodDecl(Column column)
{
    return CreatePartialMethod(GetChangingMethodName(column.Member ?? column.Name),
        new CodeParameterDeclarationExpression(ToCodeTypeReference(column), "value"));
}

/// <summary>
/// Returns the CLR type used for <paramref name="column"/>. A loadable value type in a
/// nullable column becomes System.Nullable&lt;T&gt;; every other type name is used verbatim.
/// </summary>
static CodeTypeReference ToCodeTypeReference(Column column)
{
    var t = System.Type.GetType(column.Type);
    if (t == null)
        return new CodeTypeReference(column.Type);
    return t.IsValueType && column.CanBeNull
        ? new CodeTypeReference("System.Nullable", new CodeTypeReference(column.Type))
        : new CodeTypeReference(column.Type);
}

/// <summary>Builds "a != b" (C#) / "a &lt;&gt; b" (VB).</summary>
CodeBinaryOperatorExpression ValuesAreNotEqual(CodeExpression a, CodeExpression b)
{
    return new CodeBinaryOperatorExpression(a, CodeBinaryOperatorType.IdentityInequality, b);
}

/// <summary>
/// Builds "(a == b) == false" (C#) / "(a Is b) = False" (VB), because CodeDOM exposes no VB IsNot.
/// </summary>
CodeBinaryOperatorExpression ValuesAreNotEqual_Ref(CodeExpression a, CodeExpression b)
{
    return new CodeBinaryOperatorExpression(
        new CodeBinaryOperatorExpression(
            a,
            CodeBinaryOperatorType.IdentityEquality,
            b),
        CodeBinaryOperatorType.ValueEquality,
        new CodePrimitiveExpression(false));
}

CodeBinaryOperatorExpression ValueIsNull(CodeExpression value)
{
    return new CodeBinaryOperatorExpression(
        value,
        CodeBinaryOperatorType.IdentityEquality,
        new
CodePrimitiveExpression(null)); 928 } 929 ValueIsNotNull(CodeExpression value)930 CodeBinaryOperatorExpression ValueIsNotNull(CodeExpression value) 931 { 932 return new CodeBinaryOperatorExpression( 933 value, 934 CodeBinaryOperatorType.IdentityInequality, 935 new CodePrimitiveExpression(null)); 936 } 937 GetStorageFieldName(Column column)938 static string GetStorageFieldName(Column column) 939 { 940 return GetStorageFieldName(column.Storage ?? column.Member); 941 } 942 GetStorageFieldName(string storage)943 static string GetStorageFieldName(string storage) 944 { 945 if (storage.StartsWith("_")) 946 return storage; 947 return "_" + storage; 948 } 949 GenerateINotifyPropertyChanging(CodeTypeDeclaration entity)950 private void GenerateINotifyPropertyChanging(CodeTypeDeclaration entity) 951 { 952 entity.BaseTypes.Add(typeof(INotifyPropertyChanging)); 953 var propertyChangingEvent = new CodeMemberEvent() { 954 Attributes = MemberAttributes.Public, 955 Name = "PropertyChanging", 956 Type = new CodeTypeReference(typeof(PropertyChangingEventHandler)), 957 ImplementationTypes = { 958 new CodeTypeReference(typeof(INotifyPropertyChanging)) 959 }, 960 }; 961 var eventArgs = new CodeMemberField(new CodeTypeReference(typeof(PropertyChangingEventArgs)), "emptyChangingEventArgs") { 962 Attributes = MemberAttributes.Static | MemberAttributes.Private, 963 InitExpression = new CodeObjectCreateExpression(new CodeTypeReference(typeof(PropertyChangingEventArgs)), 964 new CodePrimitiveExpression("")), 965 }; 966 var method = new CodeMemberMethod() { 967 Attributes = MemberAttributes.Family, 968 Name = "SendPropertyChanging", 969 }; 970 method.Statements.Add(new CodeVariableDeclarationStatement(typeof(PropertyChangingEventHandler), "h") { 971 InitExpression = new CodeEventReferenceExpression(thisReference, "PropertyChanging"), 972 }); 973 method.Statements.Add(new CodeConditionStatement( 974 ValueIsNotNull(new CodeVariableReferenceExpression("h")), 975 new CodeExpressionStatement( 976 
new CodeDelegateInvokeExpression(new CodeVariableReferenceExpression("h"), thisReference, new CodeFieldReferenceExpression(null, "emptyChangingEventArgs"))))); 977 978 entity.Members.Add(propertyChangingEvent); 979 entity.Members.Add(eventArgs); 980 entity.Members.Add(method); 981 } 982 GenerateINotifyPropertyChanged(CodeTypeDeclaration entity)983 private void GenerateINotifyPropertyChanged(CodeTypeDeclaration entity) 984 { 985 entity.BaseTypes.Add(typeof(INotifyPropertyChanged)); 986 987 var propertyChangedEvent = new CodeMemberEvent() { 988 Attributes = MemberAttributes.Public, 989 Name = "PropertyChanged", 990 Type = new CodeTypeReference(typeof(PropertyChangedEventHandler)), 991 ImplementationTypes = { 992 new CodeTypeReference(typeof(INotifyPropertyChanged)) 993 }, 994 }; 995 996 var method = new CodeMemberMethod() { 997 Attributes = MemberAttributes.Family, 998 Name = "SendPropertyChanged", 999 Parameters = { new CodeParameterDeclarationExpression(typeof(System.String), "propertyName") } 1000 }; 1001 method.Statements.Add(new CodeVariableDeclarationStatement(typeof(PropertyChangedEventHandler), "h") { 1002 InitExpression = new CodeEventReferenceExpression(thisReference, "PropertyChanged"), 1003 }); 1004 method.Statements.Add(new CodeConditionStatement( 1005 ValueIsNotNull(new CodeVariableReferenceExpression("h")), 1006 new CodeExpressionStatement( 1007 new CodeDelegateInvokeExpression(new CodeVariableReferenceExpression("h"), thisReference, new CodeObjectCreateExpression(typeof(PropertyChangedEventArgs), new CodeVariableReferenceExpression("propertyName")))))); 1008 1009 entity.Members.Add(propertyChangedEvent); 1010 entity.Members.Add(method); 1011 } 1012 GenerateEntityGetHashCodeAndEquals(CodeTypeDeclaration entity, Table table)1013 void GenerateEntityGetHashCodeAndEquals(CodeTypeDeclaration entity, Table table) 1014 { 1015 var primaryKeys = table.Type.Columns.Where(c => c.IsPrimaryKey); 1016 var pkCount = primaryKeys.Count(); 1017 if (pkCount == 0) 1018 { 
1019 Warning("Table {0} has no primary key(s). Skipping /generate-equals-hash for this table.", 1020 table.Name); 1021 return; 1022 } 1023 entity.BaseTypes.Add(new CodeTypeReference(typeof(IEquatable<>)) { 1024 TypeArguments = { new CodeTypeReference(entity.Name) }, 1025 }); 1026 1027 var method = new CodeMemberMethod() { 1028 Attributes = MemberAttributes.Public | MemberAttributes.Override, 1029 Name = "GetHashCode", 1030 ReturnType = new CodeTypeReference(typeof(int)), 1031 }; 1032 entity.Members.Add(method); 1033 method.Statements.Add(new CodeVariableDeclarationStatement(typeof(int), "hc", new CodePrimitiveExpression(0))); 1034 var numShifts = 32 / pkCount; 1035 int pki = 0; 1036 foreach (var pk in primaryKeys) 1037 { 1038 var shift = 1 << (pki++ * numShifts); 1039 // lack of exclusive-or means we instead split the 32-bit hash code value 1040 // into pkCount "chunks", each chunk being numShifts in size. 1041 // Thus, if there are two primary keys, the first primary key gets the 1042 // lower 16 bits, while the second primray key gets the upper 16 bits. 
1043 CodeStatement update = new CodeAssignStatement( 1044 new CodeVariableReferenceExpression("hc"), 1045 new CodeBinaryOperatorExpression( 1046 new CodeVariableReferenceExpression("hc"), 1047 CodeBinaryOperatorType.BitwiseOr, 1048 new CodeBinaryOperatorExpression( 1049 new CodeMethodInvokeExpression( 1050 new CodeMethodReferenceExpression( 1051 new CodeVariableReferenceExpression(GetStorageFieldName(pk)), 1052 "GetHashCode")), 1053 CodeBinaryOperatorType.Multiply, 1054 new CodePrimitiveExpression(shift)))); 1055 var pkType = System.Type.GetType(pk.Type); 1056 if (pk.CanBeNull || (pkType != null && (pkType.IsClass || pkType.IsNullable()))) 1057 { 1058 update = new CodeConditionStatement( 1059 ValueIsNotNull(new CodeVariableReferenceExpression(GetStorageFieldName(pk))), 1060 update); 1061 } 1062 method.Statements.Add(update); 1063 } 1064 method.Statements.Add(new CodeMethodReturnStatement(new CodeVariableReferenceExpression("hc"))); 1065 1066 method = new CodeMemberMethod() { 1067 Attributes = MemberAttributes.Public | MemberAttributes.Override, 1068 Name = "Equals", 1069 ReturnType = new CodeTypeReference(typeof(bool)), 1070 Parameters = { 1071 new CodeParameterDeclarationExpression(new CodeTypeReference(typeof(object)), "value"), 1072 }, 1073 }; 1074 entity.Members.Add(method); 1075 method.Statements.Add( 1076 new CodeConditionStatement( 1077 ValueIsNull(new CodeVariableReferenceExpression("value")), 1078 new CodeMethodReturnStatement(new CodePrimitiveExpression(false)))); 1079 method.Statements.Add( 1080 new CodeConditionStatement( 1081 ValuesAreNotEqual_Ref( 1082 new CodeMethodInvokeExpression( 1083 new CodeMethodReferenceExpression( 1084 new CodeVariableReferenceExpression("value"), 1085 "GetType")), 1086 new CodeMethodInvokeExpression( 1087 new CodeMethodReferenceExpression(thisReference, "GetType"))), 1088 new CodeMethodReturnStatement(new CodePrimitiveExpression(false)))); 1089 method.Statements.Add( 1090 new CodeVariableDeclarationStatement( 1091 new 
CodeTypeReference(entity.Name), 1092 "other", 1093 new CodeCastExpression(new CodeTypeReference(entity.Name), new CodeVariableReferenceExpression("value")))); 1094 method.Statements.Add( 1095 new CodeMethodReturnStatement( 1096 new CodeMethodInvokeExpression( 1097 new CodeMethodReferenceExpression(thisReference, "Equals"), 1098 new CodeVariableReferenceExpression("other")))); 1099 1100 method = new CodeMemberMethod() { 1101 Attributes = MemberAttributes.Public, 1102 Name = "Equals", 1103 ReturnType = new CodeTypeReference(typeof(bool)), 1104 Parameters = { 1105 new CodeParameterDeclarationExpression(new CodeTypeReference(entity.Name), "value"), 1106 }, 1107 ImplementationTypes = { 1108 new CodeTypeReference("IEquatable", new CodeTypeReference(entity.Name)), 1109 }, 1110 }; 1111 entity.Members.Add(method); 1112 method.Statements.Add( 1113 new CodeConditionStatement( 1114 ValueIsNull(new CodeVariableReferenceExpression("value")), 1115 new CodeMethodReturnStatement(new CodePrimitiveExpression(false)))); 1116 1117 CodeExpression equals = null; 1118 foreach (var pk in primaryKeys) 1119 { 1120 var compare = new CodeMethodInvokeExpression( 1121 new CodeMethodReferenceExpression( 1122 new CodePropertyReferenceExpression( 1123 new CodeTypeReferenceExpression( 1124 new CodeTypeReference("System.Collections.Generic.EqualityComparer", 1125 new CodeTypeReference(pk.Type))), 1126 "Default"), 1127 "Equals"), 1128 new CodeFieldReferenceExpression(thisReference, GetStorageFieldName(pk)), 1129 new CodeFieldReferenceExpression(new CodeVariableReferenceExpression("value"), GetStorageFieldName(pk))); 1130 equals = equals == null 1131 ? 
(CodeExpression) compare 1132 : (CodeExpression) new CodeBinaryOperatorExpression( 1133 equals, 1134 CodeBinaryOperatorType.BooleanAnd, 1135 compare); 1136 } 1137 method.Statements.Add( 1138 new CodeMethodReturnStatement(equals)); 1139 } 1140 GenerateEntityChildren(CodeTypeDeclaration entity, Table table, Database schema)1141 void GenerateEntityChildren(CodeTypeDeclaration entity, Table table, Database schema) 1142 { 1143 var children = GetClassChildren(table); 1144 if (children.Any()) 1145 { 1146 var childMembers = new List<CodeTypeMember>(); 1147 1148 foreach (var child in children) 1149 { 1150 bool hasDuplicates = (from c in children where c.Member == child.Member select c).Count() > 1; 1151 1152 // the following is apparently useless 1153 var targetTable = schema.Tables.FirstOrDefault(t => t.Type.Name == child.Type); 1154 if (targetTable == null) 1155 { 1156 //Logger.Write(Level.Error, "ERROR L143 target table class not found:" + child.Type); 1157 continue; 1158 } 1159 1160 var childType = new CodeTypeReference("EntitySet", new CodeTypeReference(child.Type)); 1161 var storage = GetStorageFieldName(child); 1162 entity.Members.Add(new CodeMemberField(childType, storage)); 1163 1164 var childName = hasDuplicates 1165 ? 
child.Member + "_" + string.Join("", child.OtherKeys.ToArray()) 1166 : child.Member; 1167 var property = new CodeMemberProperty() { 1168 Name = childName, 1169 Type = childType, 1170 Attributes = ToMemberAttributes(child), 1171 CustomAttributes = { 1172 new CodeAttributeDeclaration("Association", 1173 new CodeAttributeArgument("Storage", new CodePrimitiveExpression(GetStorageFieldName(child))), 1174 new CodeAttributeArgument("OtherKey", new CodePrimitiveExpression(child.OtherKey)), 1175 new CodeAttributeArgument("ThisKey", new CodePrimitiveExpression(child.ThisKey)), 1176 new CodeAttributeArgument("Name", new CodePrimitiveExpression(child.Name))), 1177 new CodeAttributeDeclaration("DebuggerNonUserCode"), 1178 }, 1179 }; 1180 childMembers.Add(property); 1181 property.GetStatements.Add(new CodeMethodReturnStatement( 1182 new CodeFieldReferenceExpression(thisReference, storage))); 1183 property.SetStatements.Add(new CodeAssignStatement( 1184 new CodeFieldReferenceExpression(thisReference, storage), 1185 new CodePropertySetValueReferenceExpression())); 1186 } 1187 1188 if (childMembers.Count == 0) 1189 return; 1190 childMembers.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Children")); 1191 childMembers.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null)); 1192 entity.Members.AddRange(childMembers.ToArray()); 1193 } 1194 } 1195 GetClassChildren(Table table)1196 IEnumerable<Association> GetClassChildren(Table table) 1197 { 1198 return table.Type.Associations.Where(a => !a.IsForeignKey); 1199 } 1200 ToMemberAttributes(Association association)1201 static MemberAttributes ToMemberAttributes(Association association) 1202 { 1203 MemberAttributes attrs = 0; 1204 if (!association.AccessModifierSpecified) 1205 attrs |= MemberAttributes.Public; 1206 else 1207 switch (association.AccessModifier) 1208 { 1209 case AccessModifier.Internal: attrs = MemberAttributes.Assembly; break; 1210 case AccessModifier.Private: attrs = 
MemberAttributes.Private; break; 1211 case AccessModifier.Protected: attrs = MemberAttributes.Family; break; 1212 case AccessModifier.ProtectedInternal: attrs = MemberAttributes.FamilyOrAssembly; break; 1213 case AccessModifier.Public: attrs = MemberAttributes.Public; break; 1214 default: 1215 throw new ArgumentOutOfRangeException("association", "Modifier value '" + association.AccessModifierSpecified + "' is an unsupported value."); 1216 } 1217 if (!association.ModifierSpecified) 1218 attrs |= MemberAttributes.Final; 1219 else 1220 switch (association.Modifier) 1221 { 1222 case MemberModifier.New: attrs |= MemberAttributes.New | MemberAttributes.Final; break; 1223 case MemberModifier.NewVirtual: attrs |= MemberAttributes.New; break; 1224 case MemberModifier.Override: attrs |= MemberAttributes.Override; break; 1225 case MemberModifier.Virtual: break; 1226 } 1227 return attrs; 1228 } 1229 ToMemberAttributes(Function function)1230 static MemberAttributes ToMemberAttributes(Function function) 1231 { 1232 MemberAttributes attrs = 0; 1233 if (!function.AccessModifierSpecified) 1234 attrs |= MemberAttributes.Public; 1235 else 1236 switch (function.AccessModifier) 1237 { 1238 case AccessModifier.Internal: attrs = MemberAttributes.Assembly; break; 1239 case AccessModifier.Private: attrs = MemberAttributes.Private; break; 1240 case AccessModifier.Protected: attrs = MemberAttributes.Family; break; 1241 case AccessModifier.ProtectedInternal: attrs = MemberAttributes.FamilyOrAssembly; break; 1242 case AccessModifier.Public: attrs = MemberAttributes.Public; break; 1243 default: 1244 throw new ArgumentOutOfRangeException("function", "Modifier value '" + function.AccessModifierSpecified + "' is an unsupported value."); 1245 } 1246 if (!function.ModifierSpecified) 1247 attrs |= MemberAttributes.Final; 1248 else 1249 switch (function.Modifier) 1250 { 1251 case MemberModifier.New: attrs |= MemberAttributes.New | MemberAttributes.Final; break; 1252 case MemberModifier.NewVirtual: 
attrs |= MemberAttributes.New; break; 1253 case MemberModifier.Override: attrs |= MemberAttributes.Override; break; 1254 case MemberModifier.Virtual: break; 1255 } 1256 return attrs; 1257 } 1258 GetStorageFieldName(Association association)1259 static string GetStorageFieldName(Association association) 1260 { 1261 return association.Storage != null 1262 ? GetStorageFieldName(association.Storage) 1263 : "_" + CreateIdentifier(association.Member ?? association.Name); 1264 } 1265 CreateIdentifier(string value)1266 static string CreateIdentifier(string value) 1267 { 1268 return Regex.Replace(value, @"\W", "_"); 1269 } 1270 GenerateEntityChildrenAttachment(CodeTypeDeclaration entity, Table table, Database schema)1271 void GenerateEntityChildrenAttachment(CodeTypeDeclaration entity, Table table, Database schema) 1272 { 1273 var children = GetClassChildren(table).ToList(); 1274 if (!children.Any()) 1275 return; 1276 1277 var havePrimaryKeys = table.Type.Columns.Any(c => c.IsPrimaryKey); 1278 1279 var handlers = new List<CodeTypeMember>(); 1280 1281 foreach (var child in children) 1282 { 1283 // the reverse child is the association seen from the child 1284 // we're going to use it... 1285 var reverseChild = schema.GetReverseAssociation(child); 1286 // ... 
to get the parent name 1287 var memberName = reverseChild.Member; 1288 1289 var sendPropertyChanging = new CodeExpressionStatement( 1290 new CodeMethodInvokeExpression( 1291 new CodeMethodReferenceExpression(thisReference, "SendPropertyChanging"))); 1292 1293 var attach = new CodeMemberMethod() { 1294 Name = child.Member + "_Attach", 1295 Parameters = { 1296 new CodeParameterDeclarationExpression(child.Type, "entity"), 1297 }, 1298 }; 1299 handlers.Add(attach); 1300 if (havePrimaryKeys) 1301 attach.Statements.Add(sendPropertyChanging); 1302 attach.Statements.Add( 1303 new CodeAssignStatement( 1304 new CodePropertyReferenceExpression(new CodeVariableReferenceExpression("entity"), memberName), 1305 thisReference)); 1306 1307 var detach = new CodeMemberMethod() { 1308 Name = child.Member + "_Detach", 1309 Parameters = { 1310 new CodeParameterDeclarationExpression(child.Type, "entity"), 1311 }, 1312 }; 1313 handlers.Add(detach); 1314 if (havePrimaryKeys) 1315 detach.Statements.Add(sendPropertyChanging); 1316 detach.Statements.Add( 1317 new CodeAssignStatement( 1318 new CodePropertyReferenceExpression(new CodeVariableReferenceExpression("entity"), memberName), 1319 new CodePrimitiveExpression(null))); 1320 } 1321 1322 if (handlers.Count == 0) 1323 return; 1324 1325 handlers.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Attachment handlers")); 1326 handlers.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null)); 1327 entity.Members.AddRange(handlers.ToArray()); 1328 } 1329 GenerateEntityParents(CodeTypeDeclaration entity, Table table, Database schema)1330 void GenerateEntityParents(CodeTypeDeclaration entity, Table table, Database schema) 1331 { 1332 var parents = table.Type.Associations.Where(a => a.IsForeignKey); 1333 if (!parents.Any()) 1334 return; 1335 1336 var parentMembers = new List<CodeTypeMember>(); 1337 1338 foreach (var parent in parents) 1339 { 1340 bool hasDuplicates = (from p in parents where p.Member == 
parent.Member select p).Count() > 1; 1341 // WriteClassParent(writer, parent, hasDuplicates, schema, context); 1342 // the following is apparently useless 1343 DbLinq.Schema.Dbml.Table targetTable = schema.Tables.FirstOrDefault(t => t.Type.Name == parent.Type); 1344 if (targetTable == null) 1345 { 1346 //Logger.Write(Level.Error, "ERROR L191 target table type not found: " + parent.Type + " (processing " + parent.Name + ")"); 1347 continue; 1348 } 1349 1350 string member = parent.Member; 1351 string storageField = GetStorageFieldName(parent); 1352 // TODO: remove this 1353 if (member == parent.ThisKey) 1354 { 1355 member = parent.ThisKey + targetTable.Type.Name; //repeat name to prevent collision (same as Linq) 1356 storageField = "_x_" + parent.Member; 1357 } 1358 1359 var parentType = new CodeTypeReference(targetTable.Type.Name); 1360 entity.Members.Add(new CodeMemberField(new CodeTypeReference("EntityRef", parentType), storageField) { 1361 InitExpression = new CodeObjectCreateExpression(new CodeTypeReference("EntityRef", parentType)), 1362 }); 1363 1364 var parentName = hasDuplicates 1365 ? 
member + "_" + string.Join("", parent.TheseKeys.ToArray()) 1366 : member; 1367 var property = new CodeMemberProperty() { 1368 Name = parentName, 1369 Type = parentType, 1370 Attributes = ToMemberAttributes(parent), 1371 CustomAttributes = { 1372 new CodeAttributeDeclaration("Association", 1373 new CodeAttributeArgument("Storage", new CodePrimitiveExpression(storageField)), 1374 new CodeAttributeArgument("OtherKey", new CodePrimitiveExpression(parent.OtherKey)), 1375 new CodeAttributeArgument("ThisKey", new CodePrimitiveExpression(parent.ThisKey)), 1376 new CodeAttributeArgument("Name", new CodePrimitiveExpression(parent.Name)), 1377 new CodeAttributeArgument("IsForeignKey", new CodePrimitiveExpression(parent.IsForeignKey))), 1378 new CodeAttributeDeclaration("DebuggerNonUserCode"), 1379 }, 1380 }; 1381 parentMembers.Add(property); 1382 property.GetStatements.Add(new CodeMethodReturnStatement( 1383 new CodePropertyReferenceExpression( 1384 new CodeFieldReferenceExpression(thisReference, storageField), 1385 "Entity"))); 1386 1387 // algorithm is: 1388 // 1.1. must be different than previous value 1389 // 1.2. or HasLoadedOrAssignedValue is false (but why?) 1390 // 2. implementations before change 1391 // 3. if previous value not null 1392 // 3.1. place parent in temp variable 1393 // 3.2. set [Storage].Entity to null 1394 // 3.3. remove it from parent list 1395 // 4. assign value to [Storage].Entity 1396 // 5. if value is not null 1397 // 5.1. add it to parent list 1398 // 5.2. set FK members with entity keys 1399 // 6. else 1400 // 6.1. set FK members to defaults (null or 0) 1401 // 7. 
implementationas after change 1402 var otherAssociation = schema.GetReverseAssociation(parent); 1403 var parentEntity = new CodePropertyReferenceExpression( 1404 new CodeFieldReferenceExpression(thisReference, storageField), 1405 "Entity"); 1406 var parentTable = schema.Tables.Single(t => t.Type.Associations.Contains(parent)); 1407 var childKeys = parent.TheseKeys.ToArray(); 1408 var childColumns = (from ck in childKeys select table.Type.Columns.Single(c => c.Member == ck)) 1409 .ToArray(); 1410 var parentKeys = parent.OtherKeys.ToArray(); 1411 property.SetStatements.Add(new CodeConditionStatement( 1412 // 1.1 1413 ValuesAreNotEqual_Ref(parentEntity, new CodePropertySetValueReferenceExpression()), 1414 // 2. TODO: code before the change 1415 // 3. 1416 new CodeConditionStatement( 1417 ValueIsNotNull(parentEntity), 1418 // 3.1 1419 new CodeVariableDeclarationStatement(parentType, "previous" + parent.Type, parentEntity), 1420 // 3.2 1421 new CodeAssignStatement(parentEntity, new CodePrimitiveExpression(null)), 1422 // 3.3 1423 new CodeExpressionStatement( 1424 new CodeMethodInvokeExpression( 1425 new CodeMethodReferenceExpression( 1426 new CodePropertyReferenceExpression( 1427 new CodeVariableReferenceExpression("previous" + parent.Type), 1428 otherAssociation.Member), 1429 "Remove"), 1430 thisReference))), 1431 // 4. 1432 new CodeAssignStatement(parentEntity, new CodePropertySetValueReferenceExpression()), 1433 // 5. if value is null or not... 
1434 new CodeConditionStatement( 1435 ValueIsNotNull(new CodePropertySetValueReferenceExpression()), 1436 // 5.1 1437 new CodeStatement[]{ 1438 new CodeExpressionStatement( 1439 new CodeMethodInvokeExpression( 1440 new CodeMethodReferenceExpression( 1441 new CodePropertyReferenceExpression( 1442 new CodePropertySetValueReferenceExpression(), 1443 otherAssociation.Member), 1444 "Add"), 1445 thisReference)) 1446 // 5.2 1447 }.Concat(Enumerable.Range(0, parentKeys.Length).Select(i => 1448 (CodeStatement) new CodeAssignStatement( 1449 new CodeVariableReferenceExpression(GetStorageFieldName(childColumns[i])), 1450 new CodePropertyReferenceExpression( 1451 new CodePropertySetValueReferenceExpression(), 1452 parentKeys[i])) 1453 )).ToArray(), 1454 // 6. 1455 Enumerable.Range(0, parentKeys.Length).Select(i => { 1456 var column = parentTable.Type.Columns.Single(c => c.Member == childKeys[i]); 1457 return (CodeStatement) new CodeAssignStatement( 1458 new CodeVariableReferenceExpression(GetStorageFieldName(childColumns[i])), 1459 column.CanBeNull 1460 ? (CodeExpression) new CodePrimitiveExpression(null) 1461 : (CodeExpression) new CodeDefaultValueExpression(new CodeTypeReference(column.Type))); 1462 }).ToArray()) 1463 // 7: TODO 1464 )); 1465 } 1466 1467 if (parentMembers.Count == 0) 1468 return; 1469 parentMembers.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Parents")); 1470 parentMembers.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null)); 1471 entity.Members.AddRange(parentMembers.ToArray()); 1472 } 1473 } 1474 } 1475