//
// PerformRewrite.cs
//
// Authors:
//	Chris Bacon (chrisbacon76@gmail.com)
//
// Copyright (C) 2010 Chris Bacon
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.CodeContracts.Rewrite.Ast;
using Mono.CodeContracts.Rewrite.AstVisitors;

namespace Mono.CodeContracts.Rewrite {
	/// <summary>
	/// Drives the contract rewrite of an assembly: every method is decompiled,
	/// its contract calls are replaced by their rewritten form, and overriding
	/// methods additionally receive the requires of the method they override.
	/// </summary>
	class PerformRewrite {

		/// <summary>Creates a rewriter that uses the given options.</summary>
		/// <param name="options">Rewriter options; <c>Level</c> decides whether rewritten contracts are inserted.</param>
		public PerformRewrite (RewriterOptions options)
		{
			this.options = options;
		}

		private readonly RewriterOptions options;

		// Every method that has been processed, mapped to the visitor that transformed it.
		// The value is null for methods without a body (e.g. abstract methods).
		private readonly Dictionary<MethodDefinition, TransformContractsVisitor> rewrittenMethods = new Dictionary<MethodDefinition, TransformContractsVisitor> ();

		/// <summary>Rewrites the contracts of every method in every module of the assembly.</summary>
		/// <param name="assembly">The assembly whose method bodies are modified in place.</param>
		public void Rewrite (AssemblyDefinition assembly)
		{
			foreach (ModuleDefinition module in assembly.Modules) {
				ContractsRuntime contractsRuntime = new ContractsRuntime (module, this.options);

				// NOTE(review): in Cecil 0.9+ module.Types yields top-level types only, so
				// methods of nested types are not visited here - confirm this is intended.
				var allMethods =
					from type in module.Types
					from method in type.Methods
					select method;

				// Snapshot before iterating; presumably rewriting can add members to the
				// module, which would invalidate a live enumeration - TODO confirm.
				foreach (MethodDefinition method in allMethods.ToArray ()) {
					this.RewriteMethod (module, method, contractsRuntime);
				}
			}
		}

		/// <summary>
		/// Rewrites a single method once. The method it overrides (if any) is processed
		/// first so that its requires are available for contract inheritance.
		/// </summary>
		private void RewriteMethod (ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)
		{
			if (this.rewrittenMethods.ContainsKey (method)) {
				return;
			}

			// NOTE(review): the overridden method may be declared in another module or
			// assembly; it is still processed with the current module - confirm intended.
			var overridden = this.GetOverriddenMethod (method);
			if (overridden != null) {
				this.RewriteMethod (module, overridden, contractsRuntime);
			}

			bool anyRewrites = false;
			var baseMethod = this.GetBaseOverriddenMethod (method);
			// Contract inheritance: only insert re-written contracts if level >= 2, and only
			// into methods that have a body (an abstract override has nothing to insert into).
			if (baseMethod != method && method.HasBody && this.options.Level >= 2) {
				TransformContractsVisitor vOverriddenTransform;
				this.rewrittenMethods.TryGetValue (baseMethod, out vOverriddenTransform);
				// Can be null if overriding an abstract method
				if (vOverriddenTransform != null) {
					anyRewrites = this.InsertInheritedRequires (method.Body, vOverriddenTransform);
				}
			}

			TransformContractsVisitor vTransform = null;
			if (method.HasBody) {
				vTransform = this.TransformContracts (module, method, contractsRuntime);
				if (vTransform.ContractRequiresInfo.Any ()) {
					anyRewrites = true;
				}
			}
			this.rewrittenMethods.Add (method, vTransform);

			if (anyRewrites) {
				Console.WriteLine (method);
			}
		}

		/// <summary>
		/// Inserts the rewritten requires of an overridden method at the start of a body.
		/// </summary>
		/// <returns>True if the overridden method had at least one requires.</returns>
		private bool InsertInheritedRequires (MethodBody body, TransformContractsVisitor baseTransform)
		{
			var il = body.GetILProcessor ();
			// Capture the original first instruction once. Re-reading Instructions [0] for
			// every requires would place each one in front of the previous one and so
			// reverse the order in which the preconditions are checked.
			Instruction insertBefore = body.Instructions.FirstOrDefault ();
			bool anyInserted = false;
			foreach (var inheritedRequires in baseTransform.ContractRequiresInfo) {
				this.InsertIL (il, null, insertBefore, inheritedRequires.RewrittenExpr);
				anyInserted = true;
			}
			return anyInserted;
		}

		/// <summary>
		/// Decompiles the method, finds its contract requires and replaces each original
		/// contract expression in the IL with its rewritten form.
		/// </summary>
		/// <returns>The visitor holding the requires found in the method.</returns>
		private TransformContractsVisitor TransformContracts (ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)
		{
			var body = method.Body;
			Decompile decompile = new Decompile (module, method);
			var decomp = decompile.Go ();

			TransformContractsVisitor vTransform = new TransformContractsVisitor (module, method, decompile.Instructions, contractsRuntime);
			vTransform.Visit (decomp);

			foreach (var replacement in vTransform.ContractRequiresInfo) {
				// Only insert re-written contracts if level >= 2; below that the original
				// contract is removed without a replacement.
				Expr rewritten = this.options.Level >= 2 ? replacement.RewrittenExpr : null;
				this.RewriteIL (body, decompile.Instructions, replacement.OriginalExpr, rewritten);
			}

			return vTransform;
		}

		/// <summary>
		/// Removes the instructions of <paramref name="remove"/> (if not null) and compiles
		/// <paramref name="insert"/> (if not null) in their place. Without a removal the
		/// new code goes to the start of the body.
		/// </summary>
		private void RewriteIL (MethodBody body, Dictionary<Expr,Instruction> instructionLookup, Expr remove, Expr insert)
		{
			var il = body.GetILProcessor ();
			Instruction instInsertBefore;
			if (remove != null) {
				var vInstExtent = new InstructionExtentVisitor (instructionLookup);
				vInstExtent.Visit (remove);
				// Remember the successor before the removal detaches the instructions
				instInsertBefore = vInstExtent.Instructions.Last ().Next;
				foreach (var instRemove in vInstExtent.Instructions) {
					il.Remove (instRemove);
				}
			} else {
				instInsertBefore = body.Instructions.FirstOrDefault ();
			}
			this.InsertIL (il, instructionLookup, instInsertBefore, insert);
		}

		/// <summary>
		/// Compiles <paramref name="insert"/> and emits its instructions, in order, in front
		/// of <paramref name="insertBefore"/>; they are appended to the body when there is
		/// no such instruction. Does nothing when <paramref name="insert"/> is null.
		/// </summary>
		private void InsertIL (ILProcessor il, Dictionary<Expr,Instruction> instructionLookup, Instruction insertBefore, Expr insert)
		{
			if (insert == null) {
				return;
			}
			var compiler = new CompileVisitor (il, instructionLookup, inst => {
				if (insertBefore != null) {
					il.InsertBefore (insertBefore, inst);
				} else {
					il.Append (inst);
				}
			});
			compiler.Visit (insert);
		}

		/// <summary>
		/// Finds the method that <paramref name="method"/> directly overrides.
		/// </summary>
		/// <returns>
		/// The nearest virtual base-class method with the same name and parameter types,
		/// or null if the method is not an override or no base type can be resolved.
		/// </returns>
		private MethodDefinition GetOverriddenMethod (MethodDefinition method)
		{
			if (method.IsNewSlot || !method.IsVirtual) {
				return null;
			}

			// Same name and parameter count but different parameter type names. Used only
			// when no exact match exists anywhere up the chain, which is what happens when
			// the base type is a generic instantiation (parameter declared as T there).
			MethodDefinition looseMatch = null;

			// The overridden declaration is not necessarily in the immediate base type,
			// so walk up the whole inheritance chain.
			var baseTypeRef = method.DeclaringType.BaseType;
			while (baseTypeRef != null) {
				var baseType = baseTypeRef.Resolve ();
				if (baseType == null) {
					break;
				}
				var candidates = baseType.Methods
					.Where (x => x.IsVirtual && x.Name == method.Name && x.Parameters.Count == method.Parameters.Count)
					.ToArray ();
				var exactMatch = candidates.FirstOrDefault (x => HasSameParameterTypes (x, method));
				if (exactMatch != null) {
					return exactMatch;
				}
				if (looseMatch == null && candidates.Length > 0) {
					looseMatch = candidates [0];
				}
				baseTypeRef = baseType.BaseType;
			}
			return looseMatch;
		}

		/// <summary>
		/// Compares the parameter lists of two methods by the full names of their types.
		/// </summary>
		private static bool HasSameParameterTypes (MethodDefinition first, MethodDefinition second)
		{
			if (first.Parameters.Count != second.Parameters.Count) {
				return false;
			}
			for (int i = 0; i < first.Parameters.Count; i++) {
				if (first.Parameters [i].ParameterType.FullName != second.Parameters [i].ParameterType.FullName) {
					return false;
				}
			}
			return true;
		}

		/// <summary>
		/// Follows the override chain of <paramref name="method"/> to its root.
		/// </summary>
		/// <returns>The top-most overridden method, or the method itself if it overrides nothing.</returns>
		private MethodDefinition GetBaseOverriddenMethod (MethodDefinition method)
		{
			var overridden = method;
			while (true) {
				var overriddenTemp = this.GetOverriddenMethod (overridden);
				if (overriddenTemp == null) {
					return overridden;
				}
				overridden = overriddenTemp;
			}
		}

	}
}