1 //
2 // PerformRewrite.cs
3 //
4 // Authors:
5 //	Chris Bacon (chrisbacon76@gmail.com)
6 //
7 // Copyright (C) 2010 Chris Bacon
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 //
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 //
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28 
29 using System;
30 using System.Collections.Generic;
31 using System.Linq;
32 using System.Text;
33 using Mono.Cecil;
34 using Mono.Cecil.Cil;
35 using Mono.CodeContracts.Rewrite.Ast;
36 using Mono.CodeContracts.Rewrite.AstVisitors;
37 
38 namespace Mono.CodeContracts.Rewrite {
39 	class PerformRewrite {
40 
PerformRewrite(RewriterOptions options)41 		public PerformRewrite (RewriterOptions options)
42 		{
43 			this.options = options;
44 		}
45 
46 		private RewriterOptions options;
47 		private Dictionary<MethodDefinition, TransformContractsVisitor> rewrittenMethods = new Dictionary<MethodDefinition, TransformContractsVisitor> ();
48 
Rewrite(AssemblyDefinition assembly)49 		public void Rewrite (AssemblyDefinition assembly)
50 		{
51 			foreach (ModuleDefinition module in assembly.Modules) {
52 				ContractsRuntime contractsRuntime = new ContractsRuntime(module, this.options);
53 
54 				var allMethods =
55 					from type in module.Types
56 					from method in type.Methods
57 					select method;
58 
59 				foreach (MethodDefinition method in allMethods.ToArray ()) {
60 					this.RewriteMethod (module, method, contractsRuntime);
61 				}
62 			}
63 		}
64 
RewriteMethod(ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)65 		private void RewriteMethod (ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)
66 		{
67 			if (this.rewrittenMethods.ContainsKey (method)) {
68 				return;
69 			}
70 			var overridden = this.GetOverriddenMethod (method);
71 			if (overridden != null) {
72 				this.RewriteMethod (module, overridden, contractsRuntime);
73 			}
74 			bool anyRewrites = false;
75 			var baseMethod = this.GetBaseOverriddenMethod (method);
76 			if (baseMethod != method) {
77 				// Contract inheritance must be used
78 				var vOverriddenTransform = this.rewrittenMethods [baseMethod];
79 				// Can be null if overriding an abstract method
80 				if (vOverriddenTransform != null) {
81 					if (this.options.Level >= 2) {
82 						// Only insert re-written contracts if level >= 2
83 						foreach (var inheritedRequires in vOverriddenTransform.ContractRequiresInfo) {
84 							this.RewriteIL (method.Body, null, null, inheritedRequires.RewrittenExpr);
85 							anyRewrites = true;
86 						}
87 					}
88 				}
89 			}
90 
91 			TransformContractsVisitor vTransform = null;
92 			if (method.HasBody) {
93 				vTransform = this.TransformContracts (module, method, contractsRuntime);
94 				if (vTransform.ContractRequiresInfo.Any ()) {
95 					anyRewrites = true;
96 				}
97 			}
98 			this.rewrittenMethods.Add (method, vTransform);
99 
100 			if (anyRewrites) {
101 				Console.WriteLine (method);
102 			}
103 		}
104 
TransformContracts(ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)105 		private TransformContractsVisitor TransformContracts (ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)
106 		{
107 			var body = method.Body;
108 			Decompile decompile = new Decompile (module, method);
109 			var decomp = decompile.Go ();
110 
111 			TransformContractsVisitor vTransform = new TransformContractsVisitor (module, method, decompile.Instructions, contractsRuntime);
112 			vTransform.Visit (decomp);
113 
114 			foreach (var replacement in vTransform.ContractRequiresInfo) {
115 				// Only insert re-written contracts if level >= 2
116 				Expr rewritten = this.options.Level >= 2 ? replacement.RewrittenExpr : null;
117 				this.RewriteIL (body, decompile.Instructions, replacement.OriginalExpr, rewritten);
118 			}
119 
120 			return vTransform;
121 		}
122 
RewriteIL(MethodBody body, Dictionary<Expr,Instruction> instructionLookup, Expr remove, Expr insert)123 		private void RewriteIL (MethodBody body, Dictionary<Expr,Instruction> instructionLookup, Expr remove, Expr insert)
124 		{
125 			var il = body.GetILProcessor ();
126 			Instruction instInsertBefore;
127 			if (remove != null) {
128 				var vInstExtent = new InstructionExtentVisitor (instructionLookup);
129 				vInstExtent.Visit (remove);
130 				instInsertBefore = vInstExtent.Instructions.Last ().Next;
131 				foreach (var instRemove in vInstExtent.Instructions) {
132 					il.Remove (instRemove);
133 				}
134 			} else {
135 				instInsertBefore = body.Instructions [0];
136 			}
137 			if (insert != null) {
138 				var compiler = new CompileVisitor (il, instructionLookup, inst => il.InsertBefore (instInsertBefore, inst));
139 				compiler.Visit (insert);
140 			}
141 		}
142 
GetOverriddenMethod(MethodDefinition method)143 		private MethodDefinition GetOverriddenMethod (MethodDefinition method)
144 		{
145 			if (method.IsNewSlot || !method.IsVirtual) {
146 				return null;
147 			}
148 			var baseType = method.DeclaringType.BaseType;
149 			if (baseType == null) {
150 				return null;
151 			}
152 			var overridden = baseType.Resolve ().Methods.FirstOrDefault (x => x.Name == method.Name);
153 			return overridden;
154 		}
155 
GetBaseOverriddenMethod(MethodDefinition method)156 		private MethodDefinition GetBaseOverriddenMethod (MethodDefinition method)
157 		{
158 			var overridden = method;
159 			while (true) {
160 				var overriddenTemp = this.GetOverriddenMethod (overridden);
161 				if (overriddenTemp == null) {
162 					return overridden;
163 				}
164 				overridden = overriddenTemp;
165 			}
166 		}
167 
168 	}
169 }
170