1 /*
2  * Copyright 2016 WebAssembly Community Group participants
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //
// Instruments code to check for incorrect heap access. This checks
// for dereferencing 0 (null pointer access), accessing memory past the
// current top of sbrk()-addressable memory, and accesses whose address does
// not satisfy their declared alignment.
21 //
22 
23 #include "asm_v_wasm.h"
24 #include "asmjs/shared-constants.h"
25 #include "ir/bits.h"
26 #include "ir/import-utils.h"
27 #include "ir/load-utils.h"
28 #include "pass.h"
29 #include "wasm-builder.h"
30 #include "wasm.h"
31 
32 namespace wasm {
33 
34 static const Name DYNAMICTOP_PTR_IMPORT("DYNAMICTOP_PTR");
35 static const Name GET_SBRK_PTR("emscripten_get_sbrk_ptr");
36 static const Name SBRK("sbrk");
37 static const Name SEGFAULT_IMPORT("segfault");
38 static const Name ALIGNFAULT_IMPORT("alignfault");
39 
getLoadName(Load * curr)40 static Name getLoadName(Load* curr) {
41   std::string ret = "SAFE_HEAP_LOAD_";
42   ret += curr->type.toString();
43   ret += "_" + std::to_string(curr->bytes) + "_";
44   if (LoadUtils::isSignRelevant(curr) && !curr->signed_) {
45     ret += "U_";
46   }
47   if (curr->isAtomic) {
48     ret += "A";
49   } else {
50     ret += std::to_string(curr->align);
51   }
52   return ret;
53 }
54 
getStoreName(Store * curr)55 static Name getStoreName(Store* curr) {
56   std::string ret = "SAFE_HEAP_STORE_";
57   ret += curr->valueType.toString();
58   ret += "_" + std::to_string(curr->bytes) + "_";
59   if (curr->isAtomic) {
60     ret += "A";
61   } else {
62     ret += std::to_string(curr->align);
63   }
64   return ret;
65 }
66 
67 struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
68   // If the getSbrkPtr function is implemented in the wasm, we must not
69   // instrument that, as it would lead to infinite recursion of it calling
70   // SAFE_HEAP_LOAD that calls it and so forth.
71   Name getSbrkPtr;
72 
isFunctionParallelwasm::AccessInstrumenter73   bool isFunctionParallel() override { return true; }
74 
createwasm::AccessInstrumenter75   AccessInstrumenter* create() override {
76     return new AccessInstrumenter(getSbrkPtr);
77   }
78 
AccessInstrumenterwasm::AccessInstrumenter79   AccessInstrumenter(Name getSbrkPtr) : getSbrkPtr(getSbrkPtr) {}
80 
visitLoadwasm::AccessInstrumenter81   void visitLoad(Load* curr) {
82     if (getFunction()->name == getSbrkPtr || curr->type == Type::unreachable) {
83       return;
84     }
85     Builder builder(*getModule());
86     replaceCurrent(
87       builder.makeCall(getLoadName(curr),
88                        {curr->ptr, builder.makeConstPtr(curr->offset.addr)},
89                        curr->type));
90   }
91 
visitStorewasm::AccessInstrumenter92   void visitStore(Store* curr) {
93     if (getFunction()->name == getSbrkPtr || curr->type == Type::unreachable) {
94       return;
95     }
96     Builder builder(*getModule());
97     replaceCurrent(builder.makeCall(
98       getStoreName(curr),
99       {curr->ptr, builder.makeConstPtr(curr->offset.addr), curr->value},
100       Type::none));
101   }
102 };
103 
104 struct SafeHeap : public Pass {
105   PassOptions options;
106 
runwasm::SafeHeap107   void run(PassRunner* runner, Module* module) override {
108     options = runner->options;
109     // add imports
110     addImports(module);
111     // instrument loads and stores
112     AccessInstrumenter(getSbrkPtr).run(runner, module);
113     // add helper checking funcs and imports
114     addGlobals(module, module->features);
115   }
116 
117   Name dynamicTopPtr, getSbrkPtr, sbrk, segfault, alignfault;
118 
addImportswasm::SafeHeap119   void addImports(Module* module) {
120     ImportInfo info(*module);
121     auto indexType = module->memory.indexType;
122     // Older emscripten imports env.DYNAMICTOP_PTR.
123     // Newer emscripten imports or exports emscripten_get_sbrk_ptr().
124     if (auto* existing = info.getImportedGlobal(ENV, DYNAMICTOP_PTR_IMPORT)) {
125       dynamicTopPtr = existing->name;
126     } else if (auto* existing = info.getImportedFunction(ENV, GET_SBRK_PTR)) {
127       getSbrkPtr = existing->name;
128     } else if (auto* existing = module->getExportOrNull(GET_SBRK_PTR)) {
129       getSbrkPtr = existing->value;
130     } else if (auto* existing = info.getImportedFunction(ENV, SBRK)) {
131       sbrk = existing->name;
132     } else {
133       auto* import = new Function;
134       import->name = getSbrkPtr = GET_SBRK_PTR;
135       import->module = ENV;
136       import->base = GET_SBRK_PTR;
137       import->sig = Signature(Type::none, indexType);
138       module->addFunction(import);
139     }
140     if (auto* existing = info.getImportedFunction(ENV, SEGFAULT_IMPORT)) {
141       segfault = existing->name;
142     } else {
143       auto* import = new Function;
144       import->name = segfault = SEGFAULT_IMPORT;
145       import->module = ENV;
146       import->base = SEGFAULT_IMPORT;
147       import->sig = Signature(Type::none, Type::none);
148       module->addFunction(import);
149     }
150     if (auto* existing = info.getImportedFunction(ENV, ALIGNFAULT_IMPORT)) {
151       alignfault = existing->name;
152     } else {
153       auto* import = new Function;
154       import->name = alignfault = ALIGNFAULT_IMPORT;
155       import->module = ENV;
156       import->base = ALIGNFAULT_IMPORT;
157       import->sig = Signature(Type::none, Type::none);
158       module->addFunction(import);
159     }
160   }
161 
162   bool
isPossibleAtomicOperationwasm::SafeHeap163   isPossibleAtomicOperation(Index align, Index bytes, bool shared, Type type) {
164     return align == bytes && shared && type.isInteger();
165   }
166 
addGlobalswasm::SafeHeap167   void addGlobals(Module* module, FeatureSet features) {
168     // load funcs
169     Load load;
170     for (Type type : {Type::i32, Type::i64, Type::f32, Type::f64, Type::v128}) {
171       if (type == Type::v128 && !features.hasSIMD()) {
172         continue;
173       }
174       load.type = type;
175       for (Index bytes : {1, 2, 4, 8, 16}) {
176         load.bytes = bytes;
177         if (bytes > type.getByteSize() || (type == Type::f32 && bytes != 4) ||
178             (type == Type::f64 && bytes != 8) ||
179             (type == Type::v128 && bytes != 16)) {
180           continue;
181         }
182         for (auto signed_ : {true, false}) {
183           load.signed_ = signed_;
184           if (type.isFloat() && signed_) {
185             continue;
186           }
187           for (Index align : {1, 2, 4, 8, 16}) {
188             load.align = align;
189             if (align > bytes) {
190               continue;
191             }
192             for (auto isAtomic : {true, false}) {
193               load.isAtomic = isAtomic;
194               if (isAtomic && !isPossibleAtomicOperation(
195                                 align, bytes, module->memory.shared, type)) {
196                 continue;
197               }
198               addLoadFunc(load, module);
199             }
200           }
201         }
202       }
203     }
204     // store funcs
205     Store store;
206     for (Type valueType :
207          {Type::i32, Type::i64, Type::f32, Type::f64, Type::v128}) {
208       if (valueType == Type::v128 && !features.hasSIMD()) {
209         continue;
210       }
211       store.valueType = valueType;
212       store.type = Type::none;
213       for (Index bytes : {1, 2, 4, 8, 16}) {
214         store.bytes = bytes;
215         if (bytes > valueType.getByteSize() ||
216             (valueType == Type::f32 && bytes != 4) ||
217             (valueType == Type::f64 && bytes != 8) ||
218             (valueType == Type::v128 && bytes != 16)) {
219           continue;
220         }
221         for (Index align : {1, 2, 4, 8, 16}) {
222           store.align = align;
223           if (align > bytes) {
224             continue;
225           }
226           for (auto isAtomic : {true, false}) {
227             store.isAtomic = isAtomic;
228             if (isAtomic && !isPossibleAtomicOperation(
229                               align, bytes, module->memory.shared, valueType)) {
230               continue;
231             }
232             addStoreFunc(store, module);
233           }
234         }
235       }
236     }
237   }
238 
239   // creates a function for a particular style of load
addLoadFuncwasm::SafeHeap240   void addLoadFunc(Load style, Module* module) {
241     auto name = getLoadName(&style);
242     if (module->getFunctionOrNull(name)) {
243       return;
244     }
245     auto* func = new Function;
246     func->name = name;
247     // pointer, offset
248     auto indexType = module->memory.indexType;
249     func->sig = Signature({indexType, indexType}, style.type);
250     func->vars.push_back(indexType); // pointer + offset
251     Builder builder(*module);
252     auto* block = builder.makeBlock();
253     block->list.push_back(builder.makeLocalSet(
254       2,
255       builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32,
256                          builder.makeLocalGet(0, indexType),
257                          builder.makeLocalGet(1, indexType))));
258     // check for reading past valid memory: if pointer + offset + bytes
259     block->list.push_back(
260       makeBoundsCheck(style.type, builder, 2, style.bytes, module));
261     // check proper alignment
262     if (style.align > 1) {
263       block->list.push_back(makeAlignCheck(style.align, builder, 2, module));
264     }
265     // do the load
266     auto* load = module->allocator.alloc<Load>();
267     *load = style; // basically the same as the template we are given!
268     load->ptr = builder.makeLocalGet(2, indexType);
269     Expression* last = load;
270     if (load->isAtomic && load->signed_) {
271       // atomic loads cannot be signed, manually sign it
272       last = Bits::makeSignExt(load, load->bytes, *module);
273       load->signed_ = false;
274     }
275     block->list.push_back(last);
276     block->finalize(style.type);
277     func->body = block;
278     module->addFunction(func);
279   }
280 
281   // creates a function for a particular type of store
addStoreFuncwasm::SafeHeap282   void addStoreFunc(Store style, Module* module) {
283     auto name = getStoreName(&style);
284     if (module->getFunctionOrNull(name)) {
285       return;
286     }
287     auto* func = new Function;
288     func->name = name;
289     // pointer, offset, value
290     auto indexType = module->memory.indexType;
291     func->sig = Signature({indexType, indexType, style.valueType}, Type::none);
292     func->vars.push_back(indexType); // pointer + offset
293     Builder builder(*module);
294     auto* block = builder.makeBlock();
295     block->list.push_back(builder.makeLocalSet(
296       3,
297       builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32,
298                          builder.makeLocalGet(0, indexType),
299                          builder.makeLocalGet(1, indexType))));
300     // check for reading past valid memory: if pointer + offset + bytes
301     block->list.push_back(
302       makeBoundsCheck(style.valueType, builder, 3, style.bytes, module));
303     // check proper alignment
304     if (style.align > 1) {
305       block->list.push_back(makeAlignCheck(style.align, builder, 3, module));
306     }
307     // do the store
308     auto* store = module->allocator.alloc<Store>();
309     *store = style; // basically the same as the template we are given!
310     store->ptr = builder.makeLocalGet(3, indexType);
311     store->value = builder.makeLocalGet(2, style.valueType);
312     block->list.push_back(store);
313     block->finalize(Type::none);
314     func->body = block;
315     module->addFunction(func);
316   }
317 
318   Expression*
makeAlignCheckwasm::SafeHeap319   makeAlignCheck(Address align, Builder& builder, Index local, Module* module) {
320     auto indexType = module->memory.indexType;
321     Expression* ptrBits = builder.makeLocalGet(local, indexType);
322     if (module->memory.is64()) {
323       ptrBits = builder.makeUnary(WrapInt64, ptrBits);
324     }
325     return builder.makeIf(
326       builder.makeBinary(
327         AndInt32, ptrBits, builder.makeConst(int32_t(align - 1))),
328       builder.makeCall(alignfault, {}, Type::none));
329   }
330 
makeBoundsCheckwasm::SafeHeap331   Expression* makeBoundsCheck(
332     Type type, Builder& builder, Index local, Index bytes, Module* module) {
333     auto indexType = module->memory.indexType;
334     auto upperOp = module->memory.is64()
335                      ? options.lowMemoryUnused ? LtUInt64 : EqInt64
336                      : options.lowMemoryUnused ? LtUInt32 : EqInt32;
337     auto upperBound = options.lowMemoryUnused ? PassOptions::LowMemoryBound : 0;
338     Expression* brkLocation;
339     if (sbrk.is()) {
340       brkLocation =
341         builder.makeCall(sbrk, {builder.makeConstPtr(0)}, indexType);
342     } else {
343       Expression* sbrkPtr;
344       if (dynamicTopPtr.is()) {
345         sbrkPtr = builder.makeGlobalGet(dynamicTopPtr, indexType);
346       } else {
347         sbrkPtr = builder.makeCall(getSbrkPtr, {}, indexType);
348       }
349       auto size = module->memory.is64() ? 8 : 4;
350       brkLocation = builder.makeLoad(size, false, 0, size, sbrkPtr, indexType);
351     }
352     auto gtuOp = module->memory.is64() ? GtUInt64 : GtUInt32;
353     auto addOp = module->memory.is64() ? AddInt64 : AddInt32;
354     return builder.makeIf(
355       builder.makeBinary(
356         OrInt32,
357         builder.makeBinary(upperOp,
358                            builder.makeLocalGet(local, indexType),
359                            builder.makeConstPtr(upperBound)),
360         builder.makeBinary(
361           gtuOp,
362           builder.makeBinary(addOp,
363                              builder.makeLocalGet(local, indexType),
364                              builder.makeConstPtr(bytes)),
365           brkLocation)),
366       builder.makeCall(segfault, {}, Type::none));
367   }
368 };
369 
createSafeHeapPass()370 Pass* createSafeHeapPass() { return new SafeHeap(); }
371 
372 } // namespace wasm
373