1 #region MIT license
2 //
3 // MIT license
4 //
5 // Copyright (c) 2007-2008 Jiri Moudry, Pascal Craponne
6 //
7 // Permission is hereby granted, free of charge, to any person obtaining a copy
8 // of this software and associated documentation files (the "Software"), to deal
9 // in the Software without restriction, including without limitation the rights
10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 // copies of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // The above copyright notice and this permission notice shall be included in
15 // all copies or substantial portions of the Software.
16 //
17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 // THE SOFTWARE.
24 //
25 #endregion
26 using System;
27 using System.CodeDom;
28 using System.CodeDom.Compiler;
29 using System.Collections.Generic;
30 using System.ComponentModel;
31 using System.Data;
32 using System.Data.Linq.Mapping;
33 using System.IO;
34 using System.Linq;
35 using System.Reflection;
36 using System.Text;
37 using System.Text.RegularExpressions;
38 
39 using Microsoft.CSharp;
40 using Microsoft.VisualBasic;
41 
42 using DbLinq.Schema.Dbml;
43 using DbLinq.Schema.Dbml.Adapter;
44 using DbLinq.Util;
45 
46 namespace DbMetal.Generator
47 {
48 #if !MONO_STRICT
49     public
50 #endif
51     class CodeDomGenerator : ICodeGenerator
52     {
53         CodeDomProvider Provider { get; set; }
54 
55         // Provided only for Processor.EnumerateCodeGenerators().  DO NOT USE.
CodeDomGenerator()56         public CodeDomGenerator()
57         {
58         }
59 
CodeDomGenerator(CodeDomProvider provider)60         public CodeDomGenerator(CodeDomProvider provider)
61         {
62             this.Provider = provider;
63         }
64 
65         public string LanguageCode {
66             get { return "*"; }
67         }
68 
69         public string Extension {
70             get { return "*"; }
71         }
72 
CreateFromFileExtension(string extension)73         public static CodeDomGenerator CreateFromFileExtension(string extension)
74         {
75             return CreateFromLanguage(CodeDomProvider.GetLanguageFromExtension(extension));
76         }
77 
CreateFromLanguage(string language)78         public static CodeDomGenerator CreateFromLanguage(string language)
79         {
80             return new CodeDomGenerator(CodeDomProvider.CreateProvider(language));
81         }
82 
Write(TextWriter textWriter, Database dbSchema, GenerationContext context)83         public void Write(TextWriter textWriter, Database dbSchema, GenerationContext context)
84         {
85             Context = context;
86             Provider.CreateGenerator(textWriter).GenerateCodeFromNamespace(
87                 GenerateCodeDomModel(dbSchema), textWriter,
88                 new CodeGeneratorOptions() {
89                     BracingStyle = "C",
90                     IndentString = "\t",
91                 });
92         }
93 
Warning(string format, params object[] args)94         static void Warning(string format, params object[] args)
95         {
96             Console.Error.Write(Path.GetFileName(Environment.GetCommandLineArgs()[0]));
97             Console.Error.Write(": warning: ");
98             Console.Error.WriteLine(format, args);
99         }
100 
CreatePartialMethod(string methodName, params CodeParameterDeclarationExpression[] parameters)101         private CodeTypeMember CreatePartialMethod(string methodName, params CodeParameterDeclarationExpression[] parameters)
102         {
103             string prototype = null;
104             if (Provider is CSharpCodeProvider)
105             {
106                 prototype =
107                     "\t\tpartial void {0}({1});" + Environment.NewLine +
108                     "\t\t";
109             }
110             else if (Provider is VBCodeProvider)
111             {
112                 prototype =
113                     "\t\tPartial Private Sub {0}({1})" + Environment.NewLine +
114                     "\t\tEnd Sub" + Environment.NewLine +
115                     "\t\t";
116             }
117 
118             if (prototype == null)
119             {
120                 var method = new CodeMemberMethod() {
121                     Name = methodName,
122                 };
123                 method.Parameters.AddRange(parameters);
124                 return method;
125             }
126 
127             var methodDecl = new StringWriter();
128             var gen = Provider.CreateGenerator(methodDecl);
129 
130             bool comma = false;
131             foreach (var p in parameters)
132             {
133                 if (comma)
134                     methodDecl.Write(", ");
135                 comma = true;
136                 gen.GenerateCodeFromExpression(p, methodDecl, null);
137             }
138             return new CodeSnippetTypeMember(string.Format(prototype, methodName, methodDecl.ToString()));
139         }
140 
141         CodeThisReferenceExpression thisReference = new CodeThisReferenceExpression();
142 
143         protected GenerationContext Context { get; set; }
144 
GenerateCodeDomModel(Database database)145         protected virtual CodeNamespace GenerateCodeDomModel(Database database)
146         {
147             CodeNamespace _namespace = new CodeNamespace(Context.Parameters.Namespace ?? database.ContextNamespace);
148 
149             _namespace.Imports.Add(new CodeNamespaceImport("System"));
150             _namespace.Imports.Add(new CodeNamespaceImport("System.ComponentModel"));
151 #if MONO_STRICT
152             _namespace.Imports.Add(new CodeNamespaceImport("System.Data"));
153             _namespace.Imports.Add(new CodeNamespaceImport("System.Data.Linq"));
154             _namespace.Imports.Add(new CodeNamespaceImport("System.Data.Linq.Mapping"));
155 #else
156             AddConditionalImports(_namespace.Imports,
157                 "System.Data",
158                 "MONO_STRICT",
159                 new[] { "System.Data.Linq" },
160                 new[] { "DbLinq.Data.Linq", "DbLinq.Vendor" },
161                 "System.Data.Linq.Mapping");
162 #endif
163             _namespace.Imports.Add(new CodeNamespaceImport("System.Diagnostics"));
164 
165             var time = Context.Parameters.GenerateTimestamps ? DateTime.Now.ToString("u") : "[TIMESTAMP]";
166             var header = new CodeCommentStatement(GenerateCommentBanner(database, time));
167             _namespace.Comments.Add(header);
168 
169             _namespace.Types.Add(GenerateContextClass(database));
170 #if !MONO_STRICT
171             _namespace.Types.Add(GenerateMonoStrictContextConstructors(database));
172             _namespace.Types.Add(GenerateNotMonoStrictContextConstructors(database));
173 #endif
174 
175             foreach (Table table in database.Tables)
176                 _namespace.Types.Add(GenerateTableClass(table, database));
177             return _namespace;
178         }
179 
AddConditionalImports(CodeNamespaceImportCollection imports, string firstImport, string conditional, string[] importsIfTrue, string[] importsIfFalse, string lastImport)180         void AddConditionalImports(CodeNamespaceImportCollection imports,
181                 string firstImport,
182                 string conditional,
183                 string[] importsIfTrue,
184                 string[] importsIfFalse,
185                 string lastImport)
186         {
187             if (Provider is CSharpCodeProvider)
188             {
189                 // HACK HACK HACK
190                 // Would be better if CodeDom actually supported conditional compilation constructs...
191                 // This is predecated upon CSharpCodeGenerator.GenerateNamespaceImport() being implemented as:
192                 //      output.Write ("using ");
193                 //      output.Write (GetSafeName (import.Namespace));
194                 //      output.WriteLine (';');
195                 // Thus, with "crafty" execution of the namespace, we can stuff arbitrary text in there...
196 
197                 var block = new StringBuilder();
198                 // No 'using', as GenerateNamespaceImport() writes it.
199                 block.Append(firstImport).Append(";").Append(Environment.NewLine);
200                 block.Append("#if ").Append(conditional).Append(Environment.NewLine);
201                 foreach (var ns in importsIfTrue)
202                     block.Append("\tusing ").Append(ns).Append(";").Append(Environment.NewLine);
203                 block.Append("#else   // ").Append(conditional).Append(Environment.NewLine);
204                 foreach (var ns in importsIfFalse)
205                     block.Append("\tusing ").Append(ns).Append(";").Append(Environment.NewLine);
206                 block.Append("#endif  // ").Append(conditional).Append(Environment.NewLine);
207                 block.Append("\tusing ").Append(lastImport);
208                 // No ';', as GenerateNamespaceImport() writes it.
209 
210                 imports.Add(new CodeNamespaceImport(block.ToString()));
211             }
212             else if (Provider is VBCodeProvider)
213             {
214                 // HACK HACK HACK
215                 // Would be better if CodeDom actually supported conditional compilation constructs...
216                 // This is predecated upon VBCodeGenerator.GenerateNamespaceImport() being implemented as:
217                 //      output.Write ("Imports ");
218                 //      output.Write (import.Namespace);
219                 //      output.WriteLine ();
220                 // Thus, with "crafty" execution of the namespace, we can stuff arbitrary text in there...
221 
222                 var block = new StringBuilder();
223                 // No 'Imports', as GenerateNamespaceImport() writes it.
224                 block.Append(firstImport).Append(Environment.NewLine);
225                 block.Append("#If ").Append(conditional).Append(" Then").Append(Environment.NewLine);
226                 foreach (var ns in importsIfTrue)
227                     block.Append("Imports ").Append(ns).Append(Environment.NewLine);
228                 block.Append("#Else     ' ").Append(conditional).Append(Environment.NewLine);
229                 foreach (var ns in importsIfFalse)
230                     block.Append("Imports ").Append(ns).Append(Environment.NewLine);
231                 block.Append("#End If   ' ").Append(conditional).Append(Environment.NewLine);
232                 block.Append("Imports ").Append(lastImport);
233                 // No newline, as GenerateNamespaceImport() writes it.
234 
235                 imports.Add(new CodeNamespaceImport(block.ToString()));
236             }
237             else
238             {
239                 // Default to using the DbLinq imports
240                 imports.Add(new CodeNamespaceImport(firstImport));
241                 foreach (var ns in importsIfTrue)
242                     imports.Add(new CodeNamespaceImport(ns));
243                 imports.Add(new CodeNamespaceImport(lastImport));
244             }
245         }
246 
GenerateCommentBanner(Database database, string time)247         private string GenerateCommentBanner(Database database, string time)
248         {
249             var result = new StringBuilder();
250 
251             // http://www.network-science.de/ascii/
252             // http://www.network-science.de/ascii/ascii.php?TEXT=MetalSequel&x=14&y=14&FONT=_all+fonts+with+your+text_&RICH=no&FORM=left&STRE=no&WIDT=80
253             result.Append(
254                 @"
255   ____  _     __  __      _        _
256  |  _ \| |__ |  \/  | ___| |_ __ _| |
257  | | | | '_ \| |\/| |/ _ \ __/ _` | |
258  | |_| | |_) | |  | |  __/ || (_| | |
259  |____/|_.__/|_|  |_|\___|\__\__,_|_|
260 
261 ");
262             result.AppendLine(String.Format(" Auto-generated from {0} on {1}.", database.Name, time));
263             result.AppendLine(" Please visit http://code.google.com/p/dblinq2007/ for more information.");
264 
265             return result.ToString();
266         }
267 
GenerateContextClass(Database database)268         protected virtual CodeTypeDeclaration GenerateContextClass(Database database)
269         {
270             var _class = new CodeTypeDeclaration() {
271                 IsClass         = true,
272                 IsPartial       = true,
273                 Name            = database.Class,
274                 TypeAttributes  = TypeAttributes.Public
275             };
276 
277             _class.BaseTypes.Add(GetContextBaseType(database.BaseType));
278 
279             var onCreated = CreatePartialMethod("OnCreated");
280             onCreated.StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Extensibility Method Declarations"));
281             onCreated.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
282             _class.Members.Add(onCreated);
283 
284             // Implement Constructor
285             GenerateContextConstructors(_class, database);
286 
287             foreach (Table table in database.Tables)
288             {
289                 var tableType = new CodeTypeReference(table.Type.Name);
290                 var property = new CodeMemberProperty() {
291                     Attributes  = MemberAttributes.Public | MemberAttributes.Final,
292                     Name        = table.Member,
293                     Type        = new CodeTypeReference("Table", tableType),
294                 };
295                 property.GetStatements.Add(
296                     new CodeMethodReturnStatement(
297                         new CodeMethodInvokeExpression(
298                             new CodeMethodReferenceExpression(thisReference, "GetTable", tableType))));
299                 _class.Members.Add(property);
300             }
301 
302             foreach (var function in database.Functions)
303             {
304                 GenerateContextFunction(_class, function);
305             }
306 
307             return _class;
308         }
309 
GetContextBaseType(string type)310         static string GetContextBaseType(string type)
311         {
312             string baseType = "DataContext";
313 
314             if (!string.IsNullOrEmpty(type))
315             {
316                 var t = TypeLoader.Load(type);
317                 if (t != null)
318                     baseType = t.Name;
319             }
320 
321             return baseType;
322         }
323 
GenerateContextConstructors(CodeTypeDeclaration contextType, Database database)324         void GenerateContextConstructors(CodeTypeDeclaration contextType, Database database)
325         {
326             // .ctor(string connectionString);
327             var constructor = new CodeConstructor() {
328                 Attributes = MemberAttributes.Public,
329                 Parameters = { new CodeParameterDeclarationExpression(typeof(string), "connectionString") },
330             };
331             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connectionString"));
332             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
333             contextType.Members.Add(constructor);
334 
335 #if MONO_STRICT
336             // .ctor(IDbConnection connection);
337             constructor = new CodeConstructor() {
338                 Attributes = MemberAttributes.Public,
339                 Parameters = { new CodeParameterDeclarationExpression("IDbConnection", "connection") },
340             };
341             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
342             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
343             contextType.Members.Add(constructor);
344 #endif
345 
346             // .ctor(string connection, MappingSource mappingSource);
347             constructor = new CodeConstructor() {
348                 Attributes = MemberAttributes.Public,
349                 Parameters = {
350                     new CodeParameterDeclarationExpression(typeof(string), "connection"),
351                     new CodeParameterDeclarationExpression("MappingSource", "mappingSource"),
352                 },
353             };
354             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
355             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("mappingSource"));
356             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
357             contextType.Members.Add(constructor);
358 
359             // .ctor(IDbConnection connection, MappingSource mappingSource);
360             constructor = new CodeConstructor() {
361                 Attributes = MemberAttributes.Public,
362                 Parameters = {
363                     new CodeParameterDeclarationExpression("IDbConnection", "connection"),
364                     new CodeParameterDeclarationExpression("MappingSource", "mappingSource"),
365                 },
366             };
367             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
368             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("mappingSource"));
369             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
370             contextType.Members.Add(constructor);
371         }
372 
GenerateMonoStrictContextConstructors(Database database)373         CodeTypeDeclaration GenerateMonoStrictContextConstructors(Database database)
374         {
375             var contextType = new CodeTypeDeclaration()
376             {
377                 IsClass         = true,
378                 IsPartial       = true,
379                 Name            = database.Class,
380                 TypeAttributes  = TypeAttributes.Public
381             };
382             AddConditionalIfElseBlocks(contextType, "MONO_STRICT");
383 
384             // .ctor(IDbConnection connection);
385             var constructor = new CodeConstructor() {
386                 Attributes = MemberAttributes.Public,
387                 Parameters = { new CodeParameterDeclarationExpression("IDbConnection", "connection") },
388             };
389             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
390             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
391             contextType.Members.Add(constructor);
392 
393             return contextType;
394         }
395 
AddConditionalIfElseBlocks(CodeTypeMember member, string condition)396         void AddConditionalIfElseBlocks(CodeTypeMember member, string condition)
397         {
398             string startIf = null, elseIf = null;
399             if (Provider is CSharpCodeProvider)
400             {
401                 startIf = string.Format("Start {0}{1}#if {0}{1}", condition, Environment.NewLine);
402                 elseIf  = string.Format("End {0}{1}\t#endregion{1}#else     // {0}", condition, Environment.NewLine);
403             }
404             if (Provider is VBCodeProvider)
405             {
406                 startIf = string.Format("Start {0}\"{1}#If {0} Then{1}    '", condition, Environment.NewLine);
407                 elseIf  = string.Format("End {0}\"{1}\t#End Region{1}#Else     ' {0}", condition, Environment.NewLine);
408             }
409             if (startIf != null && elseIf != null)
410             {
411                 member.StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, startIf));
412                 member.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, elseIf));
413             }
414         }
415 
AddConditionalEndifBlocks(CodeTypeMember member, string condition)416         void AddConditionalEndifBlocks(CodeTypeMember member, string condition)
417         {
418             string endIf = null;
419             if (Provider is CSharpCodeProvider)
420             {
421                 endIf   = string.Format("End Not {0}{1}\t#endregion{1}#endif     // {0}", condition, Environment.NewLine);
422             }
423             if (Provider is VBCodeProvider)
424             {
425                 endIf   = string.Format("End Not {0}\"{1}\t#End Region{1}#End If     ' {0}", condition, Environment.NewLine);
426             }
427             if (endIf != null)
428             {
429                 member.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, endIf));
430                 member.EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
431             }
432         }
433 
GenerateNotMonoStrictContextConstructors(Database database)434         CodeTypeDeclaration GenerateNotMonoStrictContextConstructors(Database database)
435         {
436             var contextType = new CodeTypeDeclaration() {
437                 IsClass         = true,
438                 IsPartial       = true,
439                 Name            = database.Class,
440                 TypeAttributes  = TypeAttributes.Public
441             };
442             AddConditionalEndifBlocks(contextType, "MONO_STRICT");
443 
444             // .ctor(IDbConnection connection);
445             var constructor = new CodeConstructor() {
446                 Attributes = MemberAttributes.Public,
447                 Parameters = { new CodeParameterDeclarationExpression("IDbConnection", "connection") },
448             };
449             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
450             constructor.BaseConstructorArgs.Add(new CodeObjectCreateExpression(Context.SchemaLoader.Vendor.GetType()));
451             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
452             contextType.Members.Add(constructor);
453 
454             // .ctor(IDbConnection connection, IVendor mappingSource);
455             constructor = new CodeConstructor() {
456                 Attributes = MemberAttributes.Public,
457                 Parameters = {
458                     new CodeParameterDeclarationExpression("IDbConnection", "connection"),
459                     new CodeParameterDeclarationExpression("IVendor", "sqlDialect"),
460                 },
461             };
462             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
463             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("sqlDialect"));
464             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
465             contextType.Members.Add(constructor);
466 
467             // .ctor(IDbConnection connection, MappingSource mappingSource, IVendor mappingSource);
468             constructor = new CodeConstructor() {
469                 Attributes = MemberAttributes.Public,
470                 Parameters = {
471                     new CodeParameterDeclarationExpression("IDbConnection", "connection"),
472                     new CodeParameterDeclarationExpression("MappingSource", "mappingSource"),
473                     new CodeParameterDeclarationExpression("IVendor", "sqlDialect"),
474                 },
475             };
476             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("connection"));
477             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("mappingSource"));
478             constructor.BaseConstructorArgs.Add(new CodeArgumentReferenceExpression("sqlDialect"));
479             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
480             contextType.Members.Add(constructor);
481 
482             return contextType;
483         }
484 
GenerateContextFunction(CodeTypeDeclaration contextType, Function function)485         void GenerateContextFunction(CodeTypeDeclaration contextType, Function function)
486         {
487             if (function == null || string.IsNullOrEmpty(function.Name))
488             {
489                 Warning("L33 Invalid storedProcdure object: missing name.");
490                 return;
491             }
492 
493             var methodRetType = GetFunctionReturnType(function);
494             var method = new CodeMemberMethod() {
495                 Attributes  = ToMemberAttributes(function),
496                 Name        = function.Method ?? function.Name,
497                 ReturnType  = methodRetType,
498                 CustomAttributes = {
499                     new CodeAttributeDeclaration("Function",
500                         new CodeAttributeArgument("Name", new CodePrimitiveExpression(function.Name)),
501                         new CodeAttributeArgument("IsComposable", new CodePrimitiveExpression(function.IsComposable))),
502                 },
503             };
504             if (method.Parameters != null)
505                 method.Parameters.AddRange(function.Parameters.Select(x => GetFunctionParameterType(x)).ToArray());
506             if (function.Return != null && !string.IsNullOrEmpty(function.Return.DbType))
507                 method.ReturnTypeCustomAttributes.Add(
508                         new CodeAttributeDeclaration("Parameter",
509                             new CodeAttributeArgument("DbType", new CodePrimitiveExpression(function.Return.DbType))));
510 
511             contextType.Members.Add(method);
512 
513             for (int i = 0; i < function.Parameters.Count; ++i)
514             {
515                 var p = function.Parameters[i];
516                 if (!p.DirectionOut)
517                     continue;
518                 method.Statements.Add(
519                         new CodeAssignStatement(
520                             new CodeVariableReferenceExpression(p.Name),
521                             new CodeDefaultValueExpression(new CodeTypeReference(p.Type))));
522             }
523 
524             var executeMethodCallArgs = new List<CodeExpression>() {
525                 thisReference,
526                 new CodeCastExpression(
527                     new CodeTypeReference("System.Reflection.MethodInfo"),
528                     new CodeMethodInvokeExpression(
529                         new CodeMethodReferenceExpression(
530                             new CodeTypeReferenceExpression("System.Reflection.MethodBase"), "GetCurrentMethod"))),
531             };
532             if (method.Parameters != null)
533                 executeMethodCallArgs.AddRange(
534                         function.Parameters.Select(p => (CodeExpression) new CodeVariableReferenceExpression(p.Name)));
535             method.Statements.Add(
536                     new CodeVariableDeclarationStatement(
537                         new CodeTypeReference("IExecuteResult"),
538                         "result",
539                         new CodeMethodInvokeExpression(
540                             new CodeMethodReferenceExpression(thisReference, "ExecuteMethodCall"),
541                             executeMethodCallArgs.ToArray())));
542             for (int i = 0; i < function.Parameters.Count; ++i)
543             {
544                 var p = function.Parameters[i];
545                 if (!p.DirectionOut)
546                     continue;
547                 method.Statements.Add(
548                         new CodeAssignStatement(
549                             new CodeVariableReferenceExpression(p.Name),
550                             new CodeCastExpression(
551                                 new CodeTypeReference(p.Type),
552                                 new CodeMethodInvokeExpression(
553                                     new CodeMethodReferenceExpression(
554                                         new CodeVariableReferenceExpression("result"),
555                                         "GetParameterValue"),
556                                     new CodePrimitiveExpression(i)))));
557             }
558 
559             if (methodRetType != null)
560             {
561                 method.Statements.Add(
562                         new CodeMethodReturnStatement(
563                             new CodeCastExpression(
564                                 method.ReturnType,
565                                 new CodePropertyReferenceExpression(
566                                     new CodeVariableReferenceExpression("result"),
567                                     "ReturnValue"))));
568             }
569         }
570 
GetFunctionReturnType(Function function)571         CodeTypeReference GetFunctionReturnType(Function function)
572         {
573             CodeTypeReference type = null;
574             if (function.Return != null)
575             {
576                 type = GetFunctionType(function.Return.Type);
577             }
578 
579             bool isDataShapeUnknown = function.ElementType == null
580                                       && function.BodyContainsSelectStatement
581                                       && !function.IsComposable;
582             if (isDataShapeUnknown)
583             {
584                 //if we don't know the shape of results, and the proc body contains some selects,
585                 //we have no choice but to return an untyped DataSet.
586                 //
587                 //TODO: either parse proc body like microsoft,
588                 //or create a little GUI tool which would call the proc with test values, to determine result shape.
589                 type = new CodeTypeReference(typeof(DataSet));
590             }
591             return type;
592         }
593 
GetFunctionType(string type)594         static CodeTypeReference GetFunctionType(string type)
595         {
596             var t = System.Type.GetType(type);
597             if (t == null)
598                 return new CodeTypeReference(type);
599             if (t.IsValueType)
600                 return new CodeTypeReference(typeof(Nullable<>)) {
601                     TypeArguments = {
602                         new CodeTypeReference(t),
603                     },
604                 };
605             return new CodeTypeReference(t);
606         }
607 
GetFunctionParameterType(Parameter parameter)608         CodeParameterDeclarationExpression GetFunctionParameterType(Parameter parameter)
609         {
610             var p = new CodeParameterDeclarationExpression(GetFunctionType(parameter.Type), parameter.Name) {
611                 CustomAttributes = {
612                     new CodeAttributeDeclaration("Parameter",
613                         new CodeAttributeArgument("Name", new CodePrimitiveExpression(parameter.Name)),
614                         new CodeAttributeArgument("DbType", new CodePrimitiveExpression(parameter.DbType))),
615                 },
616             };
617             switch (parameter.Direction)
618             {
619                 case DbLinq.Schema.Dbml.ParameterDirection.In:
620                     p.Direction = FieldDirection.In;
621                     break;
622                 case DbLinq.Schema.Dbml.ParameterDirection.Out:
623                     p.Direction = FieldDirection.Out;
624                     break;
625                 case DbLinq.Schema.Dbml.ParameterDirection.InOut:
626                     p.Direction = FieldDirection.In | FieldDirection.Out;
627                     break;
628                 default:
629                     throw new ArgumentOutOfRangeException();
630             }
631             return p;
632         }
633 
GenerateTableClass(Table table, Database database)634         protected CodeTypeDeclaration GenerateTableClass(Table table, Database database)
635         {
636             var _class = new CodeTypeDeclaration() {
637                 IsClass         = true,
638                 IsPartial       = true,
639                 Name            = table.Type.Name,
640                 TypeAttributes  = TypeAttributes.Public,
641                 CustomAttributes = {
642                     new CodeAttributeDeclaration("Table",
643                         new CodeAttributeArgument("Name", new CodePrimitiveExpression(table.Name))),
644                 },
645             };
646 
647             WriteCustomTypes(_class, table);
648 
649             var havePrimaryKeys = table.Type.Columns.Any(c => c.IsPrimaryKey);
650             if (havePrimaryKeys)
651             {
652                 GenerateINotifyPropertyChanging(_class);
653                 GenerateINotifyPropertyChanged(_class);
654             }
655 
656             // Implement Constructor
657             var constructor = new CodeConstructor() { Attributes = MemberAttributes.Public };
658             // children are EntitySet
659             foreach (var child in GetClassChildren(table))
660             {
661                 // if the association has a storage, we use it. Otherwise, we use the property name
662                 var entitySetMember = GetStorageFieldName(child);
663                 constructor.Statements.Add(
664                     new CodeAssignStatement(
665                         new CodeVariableReferenceExpression(entitySetMember),
666                         new CodeObjectCreateExpression(
667                             new CodeTypeReference("EntitySet", new CodeTypeReference(child.Type)),
668                             new CodeDelegateCreateExpression(
669                                 new CodeTypeReference("Action", new CodeTypeReference(child.Type)),
670                                 thisReference, child.Member + "_Attach"),
671                             new CodeDelegateCreateExpression(
672                                 new CodeTypeReference("Action", new CodeTypeReference(child.Type)),
673                                 thisReference, child.Member + "_Detach"))));
674             }
675             constructor.Statements.Add(new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "OnCreated")));
676             _class.Members.Add(constructor);
677 
678             if (Context.Parameters.GenerateEqualsHash)
679             {
680                 GenerateEntityGetHashCodeAndEquals(_class, table);
681             }
682 
683             GenerateExtensibilityDeclarations(_class, table);
684 
685             // todo: add these when the actually get called
686             //partial void OnLoaded();
687             //partial void OnValidate(System.Data.Linq.ChangeAction action);
688 
689             // columns
690             foreach (Column column in table.Type.Columns)
691             {
692                 var relatedAssociations = from a in table.Type.Associations
693                                           where a.IsForeignKey && a.TheseKeys.Contains(column.Name)
694                                           select a;
695 
696                 var type = ToCodeTypeReference(column);
697                 var columnMember = column.Member ?? column.Name;
698 
699                 var field = new CodeMemberField(type, GetStorageFieldName(column));
700                 _class.Members.Add(field);
701                 var fieldReference = new CodeFieldReferenceExpression(new CodeThisReferenceExpression(), field.Name);
702 
703                 var onChanging  = GetChangingMethodName(columnMember);
704                 var onChanged   = GetChangedMethodName(columnMember);
705 
706                 var property = new CodeMemberProperty();
707                 property.Type = type;
708                 property.Name = columnMember;
709                 property.Attributes = MemberAttributes.Public | MemberAttributes.Final;
710 
711                 var defAttrValues = new ColumnAttribute();
712                 var args = new List<CodeAttributeArgument>() {
713                     new CodeAttributeArgument("Storage", new CodePrimitiveExpression(GetStorageFieldName(column))),
714                     new CodeAttributeArgument("Name", new CodePrimitiveExpression(column.Name)),
715                     new CodeAttributeArgument("DbType", new CodePrimitiveExpression(column.DbType)),
716                 };
717                 if (defAttrValues.IsPrimaryKey != column.IsPrimaryKey)
718                     args.Add(new CodeAttributeArgument("IsPrimaryKey", new CodePrimitiveExpression(column.IsPrimaryKey)));
719                 if (defAttrValues.IsDbGenerated != column.IsDbGenerated)
720                     args.Add(new CodeAttributeArgument("IsDbGenerated", new CodePrimitiveExpression(column.IsDbGenerated)));
721                 if (column.AutoSync != DbLinq.Schema.Dbml.AutoSync.Default)
722                     args.Add(new CodeAttributeArgument("AutoSync",
723                         new CodeFieldReferenceExpression(new CodeTypeReferenceExpression("AutoSync"), column.AutoSync.ToString())));
724                 if (defAttrValues.CanBeNull != column.CanBeNull)
725                     args.Add(new CodeAttributeArgument("CanBeNull", new CodePrimitiveExpression(column.CanBeNull)));
726                 if (column.Expression != null)
727                     args.Add(new CodeAttributeArgument("Expression", new CodePrimitiveExpression(column.Expression)));
728                 property.CustomAttributes.Add(
729                     new CodeAttributeDeclaration("Column", args.ToArray()));
730                 property.CustomAttributes.Add(new CodeAttributeDeclaration("DebuggerNonUserCode"));
731 
732                 property.GetStatements.Add(new CodeMethodReturnStatement(fieldReference));
733 
734                 var whenUpdating = new List<CodeStatement>(
735                     from assoc in relatedAssociations
736                     select (CodeStatement) new CodeConditionStatement(
737                         new CodePropertyReferenceExpression(
738                             new CodeVariableReferenceExpression(GetStorageFieldName(assoc)),
739                             "HasLoadedOrAssignedValue"),
740                         new CodeThrowExceptionStatement(
741                             new CodeObjectCreateExpression(typeof(System.Data.Linq.ForeignKeyReferenceAlreadyHasValueException)))));
742                 whenUpdating.Add(
743                         new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, onChanging, new CodePropertySetValueReferenceExpression())));
744                 if (havePrimaryKeys)
745                     whenUpdating.Add(
746                             new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "SendPropertyChanging")));
747                 whenUpdating.Add(
748                         new CodeAssignStatement(fieldReference, new CodePropertySetValueReferenceExpression()));
749                 if (havePrimaryKeys)
750                     whenUpdating.Add(
751                             new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, "SendPropertyChanged", new CodePrimitiveExpression(property.Name))));
752                 whenUpdating.Add(
753                         new CodeExpressionStatement(new CodeMethodInvokeExpression(thisReference, onChanged)));
754 
755                 var fieldType = TypeLoader.Load(column.Type);
756                 // This is needed for VB.NET generation;
757                 // int/string/etc. can use '<>' for comparison, but NOT arrays and other reference types.
758                 // arrays/etc. require the 'Is' operator, which is CodeBinaryOperatorType.IdentityEquality.
759                 // The VB IsNot operator is not exposed from CodeDom.
760                 // Thus, we need to special-case: if fieldType is a ref or nullable type,
761                 //  generate '(field Is value) = false'; otherwise,
762                 //  generate '(field <> value)'
763                 CodeBinaryOperatorExpression condition = fieldType.IsClass || fieldType.IsNullable()
764                     ? ValuesAreNotEqual_Ref(new CodeVariableReferenceExpression(field.Name), new CodePropertySetValueReferenceExpression())
765                     : ValuesAreNotEqual(new CodeVariableReferenceExpression(field.Name), new CodePropertySetValueReferenceExpression());
766                 property.SetStatements.Add(new CodeConditionStatement(condition, whenUpdating.ToArray()));
767                 _class.Members.Add(property);
768             }
769 
770             GenerateEntityChildren(_class, table, database);
771             GenerateEntityChildrenAttachment(_class, table, database);
772             GenerateEntityParents(_class, table, database);
773 
774             return _class;
775         }
776 
WriteCustomTypes(CodeTypeDeclaration entity, Table table)777         void WriteCustomTypes(CodeTypeDeclaration entity, Table table)
778         {
779             // detect required custom types
780             foreach (var column in table.Type.Columns)
781             {
782                 var extendedType = column.ExtendedType;
783                 var enumType = extendedType as EnumType;
784                 if (enumType != null)
785                 {
786                     Context.ExtendedTypes[column] = new GenerationContext.ExtendedTypeAndName {
787                         Type = column.ExtendedType,
788                         Table = table
789                     };
790                 }
791             }
792 
793             var customTypesNames = new List<string>();
794 
795             // create names and avoid conflits
796             foreach (var extendedTypePair in Context.ExtendedTypes)
797             {
798                 if (extendedTypePair.Value.Table != table)
799                     continue;
800 
801                 if (string.IsNullOrEmpty(extendedTypePair.Value.Type.Name))
802                 {
803                     string name = extendedTypePair.Key.Member + "Type";
804                     for (; ; )
805                     {
806                         if ((from t in Context.ExtendedTypes.Values where t.Type.Name == name select t).FirstOrDefault() == null)
807                         {
808                             extendedTypePair.Value.Type.Name = name;
809                             break;
810                         }
811                         // at 3rd loop, it will look ugly, however we will never go there
812                         name = extendedTypePair.Value.Table.Type.Name + name;
813                     }
814                 }
815                 customTypesNames.Add(extendedTypePair.Value.Type.Name);
816             }
817 
818             // write custom types
819             if (customTypesNames.Count > 0)
820             {
821                 var customTypes = new List<CodeTypeDeclaration>(customTypesNames.Count);
822 
823                 foreach (var extendedTypePair in Context.ExtendedTypes)
824                 {
825                     if (extendedTypePair.Value.Table != table)
826                         continue;
827 
828                     var extendedType = extendedTypePair.Value.Type;
829                     var enumValue = extendedType as EnumType;
830 
831                     if (enumValue != null)
832                     {
833                         var enumType = new CodeTypeDeclaration(enumValue.Name) {
834                             TypeAttributes = TypeAttributes.Public,
835                             IsEnum = true,
836                         };
837                         customTypes.Add(enumType);
838                         var orderedValues = from nv in enumValue orderby nv.Value select nv;
839                         int currentValue = 1;
840                         foreach (var nameValue in orderedValues)
841                         {
842                             var field = new CodeMemberField() {
843                                 Name = nameValue.Key,
844                             };
845                             enumType.Members.Add(field);
846                             if (nameValue.Value != currentValue)
847                             {
848                                 currentValue = nameValue.Value;
849                                 field.InitExpression = new CodePrimitiveExpression(nameValue.Value);
850                             }
851                             currentValue++;
852                         }
853                     }
854                 }
855 
856                 if (customTypes.Count == 0)
857                     return;
858                 customTypes.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start,
859                         string.Format("Custom type definitions for {0}", string.Join(", ", customTypesNames.ToArray()))));
860                 customTypes.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
861                 entity.Members.AddRange(customTypes.ToArray());
862             }
863         }
864 
GenerateExtensibilityDeclarations(CodeTypeDeclaration entity, Table table)865         void GenerateExtensibilityDeclarations(CodeTypeDeclaration entity, Table table)
866         {
867             var partialMethods = new[] { CreatePartialMethod("OnCreated") }
868                 .Concat(table.Type.Columns.Select(c => new[] { CreateChangedMethodDecl(c), CreateChangingMethodDecl(c) })
869                     .SelectMany(md => md)).ToArray();
870             partialMethods.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Extensibility Method Declarations"));
871             partialMethods.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
872             entity.Members.AddRange(partialMethods);
873         }
874 
GetChangedMethodName(string columnName)875         static string GetChangedMethodName(string columnName)
876         {
877             return string.Format("On{0}Changed", columnName);
878         }
879 
CreateChangedMethodDecl(Column column)880         CodeTypeMember CreateChangedMethodDecl(Column column)
881         {
882             return CreatePartialMethod(GetChangedMethodName(column.Member));
883         }
884 
GetChangingMethodName(string columnName)885         static string GetChangingMethodName(string columnName)
886         {
887             return string.Format("On{0}Changing", columnName);
888         }
889 
CreateChangingMethodDecl(Column column)890         CodeTypeMember CreateChangingMethodDecl(Column column)
891         {
892             return CreatePartialMethod(GetChangingMethodName(column.Member),
893                     new CodeParameterDeclarationExpression(ToCodeTypeReference(column), "value"));
894         }
895 
ToCodeTypeReference(Column column)896         static CodeTypeReference ToCodeTypeReference(Column column)
897         {
898             var t = System.Type.GetType(column.Type);
899             if (t == null)
900                 return new CodeTypeReference(column.Type);
901             return t.IsValueType && column.CanBeNull
902                 ? new CodeTypeReference("System.Nullable", new CodeTypeReference(column.Type))
903                 : new CodeTypeReference(column.Type);
904         }
905 
ValuesAreNotEqual(CodeExpression a, CodeExpression b)906         CodeBinaryOperatorExpression ValuesAreNotEqual(CodeExpression a, CodeExpression b)
907         {
908             return new CodeBinaryOperatorExpression(a, CodeBinaryOperatorType.IdentityInequality, b);
909         }
910 
ValuesAreNotEqual_Ref(CodeExpression a, CodeExpression b)911         CodeBinaryOperatorExpression ValuesAreNotEqual_Ref(CodeExpression a, CodeExpression b)
912         {
913             return new CodeBinaryOperatorExpression(
914                         new CodeBinaryOperatorExpression(
915                             a,
916                             CodeBinaryOperatorType.IdentityEquality,
917                             b),
918                         CodeBinaryOperatorType.ValueEquality,
919                         new CodePrimitiveExpression(false));
920         }
921 
ValueIsNull(CodeExpression value)922         CodeBinaryOperatorExpression ValueIsNull(CodeExpression value)
923         {
924             return new CodeBinaryOperatorExpression(
925                 value,
926                 CodeBinaryOperatorType.IdentityEquality,
927                 new CodePrimitiveExpression(null));
928         }
929 
ValueIsNotNull(CodeExpression value)930         CodeBinaryOperatorExpression ValueIsNotNull(CodeExpression value)
931         {
932             return new CodeBinaryOperatorExpression(
933                 value,
934                 CodeBinaryOperatorType.IdentityInequality,
935                 new CodePrimitiveExpression(null));
936         }
937 
GetStorageFieldName(Column column)938         static string GetStorageFieldName(Column column)
939         {
940             return GetStorageFieldName(column.Storage ?? column.Member);
941         }
942 
GetStorageFieldName(string storage)943         static string GetStorageFieldName(string storage)
944         {
945             if (storage.StartsWith("_"))
946                 return storage;
947             return "_" + storage;
948         }
949 
GenerateINotifyPropertyChanging(CodeTypeDeclaration entity)950         private void GenerateINotifyPropertyChanging(CodeTypeDeclaration entity)
951         {
952             entity.BaseTypes.Add(typeof(INotifyPropertyChanging));
953             var propertyChangingEvent = new CodeMemberEvent() {
954                 Attributes  = MemberAttributes.Public,
955                 Name        = "PropertyChanging",
956                 Type        = new CodeTypeReference(typeof(PropertyChangingEventHandler)),
957                 ImplementationTypes = {
958                     new CodeTypeReference(typeof(INotifyPropertyChanging))
959                 },
960             };
961             var eventArgs = new CodeMemberField(new CodeTypeReference(typeof(PropertyChangingEventArgs)), "emptyChangingEventArgs") {
962                 Attributes      = MemberAttributes.Static | MemberAttributes.Private,
963                 InitExpression  = new CodeObjectCreateExpression(new CodeTypeReference(typeof(PropertyChangingEventArgs)),
964                     new CodePrimitiveExpression("")),
965             };
966             var method = new CodeMemberMethod() {
967                 Attributes  = MemberAttributes.Family,
968                 Name        = "SendPropertyChanging",
969             };
970             method.Statements.Add(new CodeVariableDeclarationStatement(typeof(PropertyChangingEventHandler), "h") {
971                 InitExpression  = new CodeEventReferenceExpression(thisReference, "PropertyChanging"),
972             });
973             method.Statements.Add(new CodeConditionStatement(
974                     ValueIsNotNull(new CodeVariableReferenceExpression("h")),
975                     new CodeExpressionStatement(
976                         new CodeDelegateInvokeExpression(new CodeVariableReferenceExpression("h"), thisReference, new CodeFieldReferenceExpression(null, "emptyChangingEventArgs")))));
977 
978             entity.Members.Add(propertyChangingEvent);
979             entity.Members.Add(eventArgs);
980             entity.Members.Add(method);
981         }
982 
GenerateINotifyPropertyChanged(CodeTypeDeclaration entity)983         private void GenerateINotifyPropertyChanged(CodeTypeDeclaration entity)
984         {
985             entity.BaseTypes.Add(typeof(INotifyPropertyChanged));
986 
987             var propertyChangedEvent = new CodeMemberEvent() {
988                 Attributes = MemberAttributes.Public,
989                 Name = "PropertyChanged",
990                 Type = new CodeTypeReference(typeof(PropertyChangedEventHandler)),
991                 ImplementationTypes = {
992                     new CodeTypeReference(typeof(INotifyPropertyChanged))
993                 },
994             };
995 
996             var method = new CodeMemberMethod() {
997                 Attributes = MemberAttributes.Family,
998                 Name = "SendPropertyChanged",
999                 Parameters = { new CodeParameterDeclarationExpression(typeof(System.String), "propertyName") }
1000             };
1001             method.Statements.Add(new CodeVariableDeclarationStatement(typeof(PropertyChangedEventHandler), "h") {
1002                 InitExpression = new CodeEventReferenceExpression(thisReference, "PropertyChanged"),
1003             });
1004             method.Statements.Add(new CodeConditionStatement(
1005                     ValueIsNotNull(new CodeVariableReferenceExpression("h")),
1006                     new CodeExpressionStatement(
1007                         new CodeDelegateInvokeExpression(new CodeVariableReferenceExpression("h"), thisReference, new CodeObjectCreateExpression(typeof(PropertyChangedEventArgs), new CodeVariableReferenceExpression("propertyName"))))));
1008 
1009             entity.Members.Add(propertyChangedEvent);
1010             entity.Members.Add(method);
1011         }
1012 
GenerateEntityGetHashCodeAndEquals(CodeTypeDeclaration entity, Table table)1013         void GenerateEntityGetHashCodeAndEquals(CodeTypeDeclaration entity, Table table)
1014         {
1015             var primaryKeys = table.Type.Columns.Where(c => c.IsPrimaryKey);
1016             var pkCount = primaryKeys.Count();
1017             if (pkCount == 0)
1018             {
1019                 Warning("Table {0} has no primary key(s).  Skipping /generate-equals-hash for this table.",
1020                         table.Name);
1021                 return;
1022             }
1023             entity.BaseTypes.Add(new CodeTypeReference(typeof(IEquatable<>)) {
1024                 TypeArguments = { new CodeTypeReference(entity.Name) },
1025             });
1026 
1027             var method = new CodeMemberMethod() {
1028                 Attributes  = MemberAttributes.Public | MemberAttributes.Override,
1029                 Name        = "GetHashCode",
1030                 ReturnType  = new CodeTypeReference(typeof(int)),
1031             };
1032             entity.Members.Add(method);
1033             method.Statements.Add(new CodeVariableDeclarationStatement(typeof(int), "hc", new CodePrimitiveExpression(0)));
1034             var numShifts = 32 / pkCount;
1035             int pki = 0;
1036             foreach (var pk in primaryKeys)
1037             {
1038                 var shift = 1 << (pki++ * numShifts);
1039                 // lack of exclusive-or means we instead split the 32-bit hash code value
1040                 // into pkCount "chunks", each chunk being numShifts in size.
1041                 // Thus, if there are two primary keys, the first primary key gets the
1042                 // lower 16 bits, while the second primray key gets the upper 16 bits.
1043                 CodeStatement update = new CodeAssignStatement(
1044                         new CodeVariableReferenceExpression("hc"),
1045                         new CodeBinaryOperatorExpression(
1046                             new CodeVariableReferenceExpression("hc"),
1047                             CodeBinaryOperatorType.BitwiseOr,
1048                             new CodeBinaryOperatorExpression(
1049                                 new CodeMethodInvokeExpression(
1050                                     new CodeMethodReferenceExpression(
1051                                         new CodeVariableReferenceExpression(GetStorageFieldName(pk)),
1052                                         "GetHashCode")),
1053                                 CodeBinaryOperatorType.Multiply,
1054                                 new CodePrimitiveExpression(shift))));
1055                 var pkType = System.Type.GetType(pk.Type);
1056                 if (pk.CanBeNull || (pkType != null && (pkType.IsClass || pkType.IsNullable())))
1057                 {
1058                     update = new CodeConditionStatement(
1059                             ValueIsNotNull(new CodeVariableReferenceExpression(GetStorageFieldName(pk))),
1060                             update);
1061                 }
1062                 method.Statements.Add(update);
1063             }
1064             method.Statements.Add(new CodeMethodReturnStatement(new CodeVariableReferenceExpression("hc")));
1065 
1066             method = new CodeMemberMethod() {
1067                 Attributes  = MemberAttributes.Public | MemberAttributes.Override,
1068                 Name        = "Equals",
1069                 ReturnType  = new CodeTypeReference(typeof(bool)),
1070                 Parameters = {
1071                     new CodeParameterDeclarationExpression(new CodeTypeReference(typeof(object)), "value"),
1072                 },
1073             };
1074             entity.Members.Add(method);
1075             method.Statements.Add(
1076                     new CodeConditionStatement(
1077                         ValueIsNull(new CodeVariableReferenceExpression("value")),
1078                         new CodeMethodReturnStatement(new CodePrimitiveExpression(false))));
1079             method.Statements.Add(
1080                     new CodeConditionStatement(
1081                         ValuesAreNotEqual_Ref(
1082                             new CodeMethodInvokeExpression(
1083                                 new CodeMethodReferenceExpression(
1084                                     new CodeVariableReferenceExpression("value"),
1085                                     "GetType")),
1086                             new CodeMethodInvokeExpression(
1087                                 new CodeMethodReferenceExpression(thisReference, "GetType"))),
1088                         new CodeMethodReturnStatement(new CodePrimitiveExpression(false))));
1089             method.Statements.Add(
1090                     new CodeVariableDeclarationStatement(
1091                         new CodeTypeReference(entity.Name),
1092                         "other",
1093                         new CodeCastExpression(new CodeTypeReference(entity.Name), new CodeVariableReferenceExpression("value"))));
1094             method.Statements.Add(
1095                     new CodeMethodReturnStatement(
1096                         new CodeMethodInvokeExpression(
1097                             new CodeMethodReferenceExpression(thisReference, "Equals"),
1098                             new CodeVariableReferenceExpression("other"))));
1099 
1100             method = new CodeMemberMethod() {
1101                 Attributes  = MemberAttributes.Public,
1102                 Name        = "Equals",
1103                 ReturnType  = new CodeTypeReference(typeof(bool)),
1104                 Parameters  = {
1105                     new CodeParameterDeclarationExpression(new CodeTypeReference(entity.Name), "value"),
1106                 },
1107                 ImplementationTypes = {
1108                     new CodeTypeReference("IEquatable", new CodeTypeReference(entity.Name)),
1109                 },
1110             };
1111             entity.Members.Add(method);
1112             method.Statements.Add(
1113                     new CodeConditionStatement(
1114                         ValueIsNull(new CodeVariableReferenceExpression("value")),
1115                         new CodeMethodReturnStatement(new CodePrimitiveExpression(false))));
1116 
1117             CodeExpression equals = null;
1118             foreach (var pk in primaryKeys)
1119             {
1120                 var compare = new CodeMethodInvokeExpression(
1121                         new CodeMethodReferenceExpression(
1122                             new CodePropertyReferenceExpression(
1123                                 new CodeTypeReferenceExpression(
1124                                     new CodeTypeReference("System.Collections.Generic.EqualityComparer",
1125                                         new CodeTypeReference(pk.Type))),
1126                                 "Default"),
1127                             "Equals"),
1128                         new CodeFieldReferenceExpression(thisReference, GetStorageFieldName(pk)),
1129                         new CodeFieldReferenceExpression(new CodeVariableReferenceExpression("value"), GetStorageFieldName(pk)));
1130                 equals = equals == null
1131                     ? (CodeExpression) compare
1132                     : (CodeExpression) new CodeBinaryOperatorExpression(
1133                         equals,
1134                         CodeBinaryOperatorType.BooleanAnd,
1135                         compare);
1136             }
1137             method.Statements.Add(
1138                     new CodeMethodReturnStatement(equals));
1139         }
1140 
GenerateEntityChildren(CodeTypeDeclaration entity, Table table, Database schema)1141         void GenerateEntityChildren(CodeTypeDeclaration entity, Table table, Database schema)
1142         {
1143             var children = GetClassChildren(table);
1144             if (children.Any())
1145             {
1146                 var childMembers = new List<CodeTypeMember>();
1147 
1148                 foreach (var child in children)
1149                 {
1150                     bool hasDuplicates = (from c in children where c.Member == child.Member select c).Count() > 1;
1151 
1152                     // the following is apparently useless
1153                     var targetTable = schema.Tables.FirstOrDefault(t => t.Type.Name == child.Type);
1154                     if (targetTable == null)
1155                     {
1156                         //Logger.Write(Level.Error, "ERROR L143 target table class not found:" + child.Type);
1157                         continue;
1158                     }
1159 
1160                     var childType = new CodeTypeReference("EntitySet", new CodeTypeReference(child.Type));
1161                     var storage = GetStorageFieldName(child);
1162                     entity.Members.Add(new CodeMemberField(childType, storage));
1163 
1164                     var childName = hasDuplicates
1165                         ? child.Member + "_" + string.Join("", child.OtherKeys.ToArray())
1166                         : child.Member;
1167                     var property = new CodeMemberProperty() {
1168                         Name        = childName,
1169                         Type        = childType,
1170                         Attributes  = ToMemberAttributes(child),
1171                         CustomAttributes = {
1172                             new CodeAttributeDeclaration("Association",
1173                                 new CodeAttributeArgument("Storage", new CodePrimitiveExpression(GetStorageFieldName(child))),
1174                                 new CodeAttributeArgument("OtherKey", new CodePrimitiveExpression(child.OtherKey)),
1175                                 new CodeAttributeArgument("ThisKey", new CodePrimitiveExpression(child.ThisKey)),
1176                                 new CodeAttributeArgument("Name", new CodePrimitiveExpression(child.Name))),
1177                             new CodeAttributeDeclaration("DebuggerNonUserCode"),
1178                         },
1179                     };
1180                     childMembers.Add(property);
1181                     property.GetStatements.Add(new CodeMethodReturnStatement(
1182                             new CodeFieldReferenceExpression(thisReference, storage)));
1183                     property.SetStatements.Add(new CodeAssignStatement(
1184                             new CodeFieldReferenceExpression(thisReference, storage),
1185                             new CodePropertySetValueReferenceExpression()));
1186                 }
1187 
1188                 if (childMembers.Count == 0)
1189                     return;
1190                 childMembers.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Children"));
1191                 childMembers.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
1192                 entity.Members.AddRange(childMembers.ToArray());
1193             }
1194         }
1195 
GetClassChildren(Table table)1196         IEnumerable<Association> GetClassChildren(Table table)
1197         {
1198             return table.Type.Associations.Where(a => !a.IsForeignKey);
1199         }
1200 
ToMemberAttributes(Association association)1201         static MemberAttributes ToMemberAttributes(Association association)
1202         {
1203             MemberAttributes attrs = 0;
1204             if (!association.AccessModifierSpecified)
1205                 attrs |= MemberAttributes.Public;
1206             else
1207                 switch (association.AccessModifier)
1208                 {
1209                     case AccessModifier.Internal:           attrs = MemberAttributes.Assembly; break;
1210                     case AccessModifier.Private:            attrs = MemberAttributes.Private; break;
1211                     case AccessModifier.Protected:          attrs = MemberAttributes.Family; break;
1212                     case AccessModifier.ProtectedInternal:  attrs = MemberAttributes.FamilyOrAssembly; break;
1213                     case AccessModifier.Public:             attrs = MemberAttributes.Public; break;
1214                     default:
1215                         throw new ArgumentOutOfRangeException("association", "Modifier value '" + association.AccessModifierSpecified + "' is an unsupported value.");
1216                 }
1217             if (!association.ModifierSpecified)
1218                 attrs |= MemberAttributes.Final;
1219             else
1220                 switch (association.Modifier)
1221                 {
1222                     case MemberModifier.New:        attrs |= MemberAttributes.New | MemberAttributes.Final; break;
1223                     case MemberModifier.NewVirtual: attrs |= MemberAttributes.New; break;
1224                     case MemberModifier.Override:   attrs |= MemberAttributes.Override; break;
1225                     case MemberModifier.Virtual:    break;
1226                 }
1227             return attrs;
1228         }
1229 
ToMemberAttributes(Function function)1230         static MemberAttributes ToMemberAttributes(Function function)
1231         {
1232             MemberAttributes attrs = 0;
1233             if (!function.AccessModifierSpecified)
1234                 attrs |= MemberAttributes.Public;
1235             else
1236                 switch (function.AccessModifier)
1237                 {
1238                     case AccessModifier.Internal:           attrs = MemberAttributes.Assembly; break;
1239                     case AccessModifier.Private:            attrs = MemberAttributes.Private; break;
1240                     case AccessModifier.Protected:          attrs = MemberAttributes.Family; break;
1241                     case AccessModifier.ProtectedInternal:  attrs = MemberAttributes.FamilyOrAssembly; break;
1242                     case AccessModifier.Public:             attrs = MemberAttributes.Public; break;
1243                     default:
1244                         throw new ArgumentOutOfRangeException("function", "Modifier value '" + function.AccessModifierSpecified + "' is an unsupported value.");
1245                 }
1246             if (!function.ModifierSpecified)
1247                 attrs |= MemberAttributes.Final;
1248             else
1249                 switch (function.Modifier)
1250                 {
1251                     case MemberModifier.New:        attrs |= MemberAttributes.New | MemberAttributes.Final; break;
1252                     case MemberModifier.NewVirtual: attrs |= MemberAttributes.New; break;
1253                     case MemberModifier.Override:   attrs |= MemberAttributes.Override; break;
1254                     case MemberModifier.Virtual:    break;
1255                 }
1256             return attrs;
1257         }
1258 
GetStorageFieldName(Association association)1259         static string GetStorageFieldName(Association association)
1260         {
1261             return association.Storage != null
1262                 ? GetStorageFieldName(association.Storage)
1263                 : "_" + CreateIdentifier(association.Member ?? association.Name);
1264         }
1265 
CreateIdentifier(string value)1266         static string CreateIdentifier(string value)
1267         {
1268             return Regex.Replace(value, @"\W", "_");
1269         }
1270 
GenerateEntityChildrenAttachment(CodeTypeDeclaration entity, Table table, Database schema)1271         void GenerateEntityChildrenAttachment(CodeTypeDeclaration entity, Table table, Database schema)
1272         {
1273             var children = GetClassChildren(table).ToList();
1274             if (!children.Any())
1275                 return;
1276 
1277             var havePrimaryKeys = table.Type.Columns.Any(c => c.IsPrimaryKey);
1278 
1279             var handlers = new List<CodeTypeMember>();
1280 
1281             foreach (var child in children)
1282             {
1283                 // the reverse child is the association seen from the child
1284                 // we're going to use it...
1285                 var reverseChild = schema.GetReverseAssociation(child);
1286                 // ... to get the parent name
1287                 var memberName = reverseChild.Member;
1288 
1289                 var sendPropertyChanging = new CodeExpressionStatement(
1290                         new CodeMethodInvokeExpression(
1291                             new CodeMethodReferenceExpression(thisReference, "SendPropertyChanging")));
1292 
1293                 var attach = new CodeMemberMethod() {
1294                     Name = child.Member + "_Attach",
1295                     Parameters = {
1296                         new CodeParameterDeclarationExpression(child.Type, "entity"),
1297                     },
1298                 };
1299                 handlers.Add(attach);
1300                 if (havePrimaryKeys)
1301                     attach.Statements.Add(sendPropertyChanging);
1302                 attach.Statements.Add(
1303                         new CodeAssignStatement(
1304                             new CodePropertyReferenceExpression(new CodeVariableReferenceExpression("entity"), memberName),
1305                             thisReference));
1306 
1307                 var detach = new CodeMemberMethod() {
1308                     Name = child.Member + "_Detach",
1309                     Parameters = {
1310                         new CodeParameterDeclarationExpression(child.Type, "entity"),
1311                     },
1312                 };
1313                 handlers.Add(detach);
1314                 if (havePrimaryKeys)
1315                     detach.Statements.Add(sendPropertyChanging);
1316                 detach.Statements.Add(
1317                         new CodeAssignStatement(
1318                             new CodePropertyReferenceExpression(new CodeVariableReferenceExpression("entity"), memberName),
1319                             new CodePrimitiveExpression(null)));
1320             }
1321 
1322             if (handlers.Count == 0)
1323                 return;
1324 
1325             handlers.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Attachment handlers"));
1326             handlers.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
1327             entity.Members.AddRange(handlers.ToArray());
1328         }
1329 
        // Generates the "parent" side of the table's associations on the entity class.
        //
        // For each association of table.Type flagged IsForeignKey whose target type
        // matches a table in schema.Tables, this adds:
        //   - an EntityRef<ParentType> storage field (added to entity.Members right away), and
        //   - a property of the parent type carrying [Association] and
        //     [DebuggerNonUserCode] attributes, whose getter returns <storage>.Entity
        //     and whose setter follows the algorithm described further down.
        // The properties are collected, wrapped in a "Parents" region and appended to
        // entity.Members at the end. Associations with no matching target table are skipped.
        //
        // entity: the type declaration receiving the generated members.
        // table:  the table whose foreign-key associations are processed.
        // schema: the database model used to resolve target tables and reverse associations.
        void GenerateEntityParents(CodeTypeDeclaration entity, Table table, Database schema)
        {
            // Deferred query: it is re-evaluated by Any(), by the foreach and by the
            // duplicate check inside the loop.
            var parents = table.Type.Associations.Where(a => a.IsForeignKey);
            if (!parents.Any())
                return;

            var parentMembers = new List<CodeTypeMember>();

            foreach (var parent in parents)
            {
                // True when more than one foreign-key association of this table uses the
                // same Member; the property name is then suffixed with the key names below.
                bool hasDuplicates = (from p in parents where p.Member == parent.Member select p).Count() > 1;
                // WriteClassParent(writer, parent, hasDuplicates, schema, context);
                // Resolve the table whose type name equals the association's target type;
                // its type name is used for the generated field and property types.
                DbLinq.Schema.Dbml.Table targetTable = schema.Tables.FirstOrDefault(t => t.Type.Name == parent.Type);
                if (targetTable == null)
                {
                    //Logger.Write(Level.Error, "ERROR L191 target table type not found: " + parent.Type + "  (processing " + parent.Name + ")");
                    continue;
                }

                string member = parent.Member;
                string storageField = GetStorageFieldName(parent);
                // TODO: remove this
                if (member == parent.ThisKey)
                {
                    member = parent.ThisKey + targetTable.Type.Name; //repeat name to prevent collision (same as Linq)
                    storageField = "_x_" + parent.Member;
                }

                // Storage field: EntityRef<ParentType> <storageField> = new EntityRef<ParentType>();
                var parentType = new CodeTypeReference(targetTable.Type.Name);
                entity.Members.Add(new CodeMemberField(new CodeTypeReference("EntityRef", parentType), storageField) {
                    InitExpression = new CodeObjectCreateExpression(new CodeTypeReference("EntityRef", parentType)),
                });

                // Disambiguate properties sharing the same member name by appending the
                // concatenated keys of this side of the association.
                var parentName = hasDuplicates
                    ? member + "_" + string.Join("", parent.TheseKeys.ToArray())
                    : member;
                var property = new CodeMemberProperty() {
                    Name        = parentName,
                    Type        = parentType,
                    Attributes  = ToMemberAttributes(parent),
                    CustomAttributes = {
                        new CodeAttributeDeclaration("Association",
                            new CodeAttributeArgument("Storage", new CodePrimitiveExpression(storageField)),
                            new CodeAttributeArgument("OtherKey", new CodePrimitiveExpression(parent.OtherKey)),
                            new CodeAttributeArgument("ThisKey", new CodePrimitiveExpression(parent.ThisKey)),
                            new CodeAttributeArgument("Name", new CodePrimitiveExpression(parent.Name)),
                            new CodeAttributeArgument("IsForeignKey", new CodePrimitiveExpression(parent.IsForeignKey))),
                        new CodeAttributeDeclaration("DebuggerNonUserCode"),
                    },
                };
                parentMembers.Add(property);
                // Getter: return this.<storageField>.Entity;
                property.GetStatements.Add(new CodeMethodReturnStatement(
                        new CodePropertyReferenceExpression(
                            new CodeFieldReferenceExpression(thisReference, storageField),
                            "Entity")));

                // Setter algorithm is:
                // 1.1. must be different than previous value
                // 1.2. or HasLoadedOrAssignedValue is false (but why?)
                //      (only the 1.1 comparison is passed as the condition below)
                // 2. implementations before change
                // 3. if previous value not null
                // 3.1. place parent in temp variable
                // 3.2. set [Storage].Entity to null
                // 3.3. remove it from parent list
                // 4. assign value to [Storage].Entity
                // 5. if value is not null
                // 5.1. add it to parent list
                // 5.2. set FK members with entity keys
                // 6. else
                // 6.1. set FK members to defaults (null or 0)
                // 7. implementations after change
                var otherAssociation = schema.GetReverseAssociation(parent);
                var parentEntity = new CodePropertyReferenceExpression(
                        new CodeFieldReferenceExpression(thisReference, storageField),
                        "Entity");
                // NOTE(review): despite its name, this selects the table whose associations
                // contain `parent`; since `parent` comes from table.Type.Associations this is
                // presumably the current table rather than the parent's table - confirm.
                var parentTable = schema.Tables.Single(t => t.Type.Associations.Contains(parent));
                // Key member names on this side, and the matching columns of the current table.
                var childKeys = parent.TheseKeys.ToArray();
                var childColumns = (from ck in childKeys select table.Type.Columns.Single(c => c.Member == ck))
                                    .ToArray();
                // Key member names on the other side; paired by index with childColumns below.
                var parentKeys = parent.OtherKeys.ToArray();
                property.SetStatements.Add(new CodeConditionStatement(
                        // 1.1
                        ValuesAreNotEqual_Ref(parentEntity, new CodePropertySetValueReferenceExpression()),
                        // 2. TODO: code before the change
                        // 3.
                        new CodeConditionStatement(
                            ValueIsNotNull(parentEntity),
                            // 3.1
                            new CodeVariableDeclarationStatement(parentType, "previous" + parent.Type, parentEntity),
                            // 3.2
                            new CodeAssignStatement(parentEntity, new CodePrimitiveExpression(null)),
                            // 3.3
                            new CodeExpressionStatement(
                                 new CodeMethodInvokeExpression(
                                    new CodeMethodReferenceExpression(
                                        new CodePropertyReferenceExpression(
                                            new CodeVariableReferenceExpression("previous" + parent.Type),
                                            otherAssociation.Member),
                                        "Remove"),
                                    thisReference))),
                        // 4.
                        new CodeAssignStatement(parentEntity, new CodePropertySetValueReferenceExpression()),
                        // 5. if value is null or not...
                        new CodeConditionStatement(
                            ValueIsNotNull(new CodePropertySetValueReferenceExpression()),
                            // 5.1
                            new CodeStatement[]{
                                new CodeExpressionStatement(
                                    new CodeMethodInvokeExpression(
                                        new CodeMethodReferenceExpression(
                                            new CodePropertyReferenceExpression(
                                                new CodePropertySetValueReferenceExpression(),
                                                otherAssociation.Member),
                                            "Add"),
                                        thisReference))
                            // 5.2: one assignment per key, <child column storage> = value.<parent key>
                            }.Concat(Enumerable.Range(0, parentKeys.Length).Select(i =>
                                (CodeStatement) new CodeAssignStatement(
                                    new CodeVariableReferenceExpression(GetStorageFieldName(childColumns[i])),
                                    new CodePropertyReferenceExpression(
                                        new CodePropertySetValueReferenceExpression(),
                                        parentKeys[i]))
                            )).ToArray(),
                            // 6. one assignment per key: null when the column can be null,
                            //    otherwise default(<column type>)
                            Enumerable.Range(0, parentKeys.Length).Select(i => {
                                var column = parentTable.Type.Columns.Single(c => c.Member == childKeys[i]);
                                return (CodeStatement) new CodeAssignStatement(
                                    new CodeVariableReferenceExpression(GetStorageFieldName(childColumns[i])),
                                    column.CanBeNull
                                        ? (CodeExpression) new CodePrimitiveExpression(null)
                                        : (CodeExpression) new CodeDefaultValueExpression(new CodeTypeReference(column.Type)));
                            }).ToArray())
                        // 7: TODO
                ));
            }

            // Every association may have been skipped above (no matching target table).
            if (parentMembers.Count == 0)
                return;
            parentMembers.First().StartDirectives.Add(new CodeRegionDirective(CodeRegionMode.Start, "Parents"));
            parentMembers.Last().EndDirectives.Add(new CodeRegionDirective(CodeRegionMode.End, null));
            entity.Members.AddRange(parentMembers.ToArray());
        }
1473     }
1474 }
1475