1 //==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file is part of the WebAssembly Assembler.
11 ///
/// It contains code to type check the instructions of a parsed .s file against
/// a simulated WebAssembly value stack.
13 ///
14 //===----------------------------------------------------------------------===//
15
16 #include "AsmParser/WebAssemblyAsmTypeCheck.h"
17 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18 #include "MCTargetDesc/WebAssemblyTargetStreamer.h"
19 #include "TargetInfo/WebAssemblyTargetInfo.h"
20 #include "Utils/WebAssemblyTypeUtilities.h"
21 #include "Utils/WebAssemblyUtilities.h"
22 #include "WebAssembly.h"
23 #include "llvm/MC/MCContext.h"
24 #include "llvm/MC/MCExpr.h"
25 #include "llvm/MC/MCInst.h"
26 #include "llvm/MC/MCInstrInfo.h"
27 #include "llvm/MC/MCParser/MCParsedAsmOperand.h"
28 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
29 #include "llvm/MC/MCSectionWasm.h"
30 #include "llvm/MC/MCStreamer.h"
31 #include "llvm/MC/MCSubtargetInfo.h"
32 #include "llvm/MC/MCSymbol.h"
33 #include "llvm/MC/MCSymbolWasm.h"
34 #include "llvm/MC/TargetRegistry.h"
35 #include "llvm/Support/Compiler.h"
36 #include "llvm/Support/Endian.h"
37 #include "llvm/Support/SourceMgr.h"
38
39 using namespace llvm;
40
41 #define DEBUG_TYPE "wasm-asm-parser"
42
43 extern StringRef GetMnemonic(unsigned Opc);
44
45 namespace llvm {
46
WebAssemblyAsmTypeCheck(MCAsmParser & Parser,const MCInstrInfo & MII,bool is64)47 WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
48 const MCInstrInfo &MII, bool is64)
49 : Parser(Parser), MII(MII), is64(is64) {
50 }
51
funcDecl(const wasm::WasmSignature & Sig)52 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
53 LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
54 ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
55 }
56
localDecl(const SmallVector<wasm::ValType,4> & Locals)57 void WebAssemblyAsmTypeCheck::localDecl(const SmallVector<wasm::ValType, 4> &Locals) {
58 LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end());
59 }
60
dumpTypeStack(Twine Msg)61 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
62 LLVM_DEBUG({
63 std::string s;
64 for (auto VT : Stack) {
65 s += WebAssembly::typeToString(VT);
66 s += " ";
67 }
68 dbgs() << Msg << s << '\n';
69 });
70 }
71
typeError(SMLoc ErrorLoc,const Twine & Msg)72 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
73 // Once you get one type error in a function, it will likely trigger more
74 // which are mostly not helpful.
75 if (TypeErrorThisFunction)
76 return true;
77 // If we're currently in unreachable code, we suppress errors completely.
78 if (Unreachable)
79 return false;
80 TypeErrorThisFunction = true;
81 dumpTypeStack("current stack: ");
82 return Parser.Error(ErrorLoc, Msg);
83 }
84
popType(SMLoc ErrorLoc,std::optional<wasm::ValType> EVT)85 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
86 std::optional<wasm::ValType> EVT) {
87 if (Stack.empty()) {
88 return typeError(ErrorLoc,
89 EVT ? StringRef("empty stack while popping ") +
90 WebAssembly::typeToString(*EVT)
91 : StringRef("empty stack while popping value"));
92 }
93 auto PVT = Stack.pop_back_val();
94 if (EVT && *EVT != PVT) {
95 return typeError(ErrorLoc,
96 StringRef("popped ") + WebAssembly::typeToString(PVT) +
97 ", expected " + WebAssembly::typeToString(*EVT));
98 }
99 return false;
100 }
101
popRefType(SMLoc ErrorLoc)102 bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
103 if (Stack.empty()) {
104 return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
105 }
106 auto PVT = Stack.pop_back_val();
107 if (!WebAssembly::isRefType(PVT)) {
108 return typeError(ErrorLoc, StringRef("popped ") +
109 WebAssembly::typeToString(PVT) +
110 ", expected reftype");
111 }
112 return false;
113 }
114
getLocal(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)115 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
116 wasm::ValType &Type) {
117 auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
118 if (Local >= LocalTypes.size())
119 return typeError(ErrorLoc, StringRef("no local type specified for index ") +
120 std::to_string(Local));
121 Type = LocalTypes[Local];
122 return false;
123 }
124
checkEnd(SMLoc ErrorLoc,bool PopVals)125 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
126 if (LastSig.Returns.size() > Stack.size())
127 return typeError(ErrorLoc, "end: insufficient values on the type stack");
128
129 if (PopVals) {
130 for (auto VT : llvm::reverse(LastSig.Returns)) {
131 if (popType(ErrorLoc, VT))
132 return true;
133 }
134 return false;
135 }
136
137 for (size_t i = 0; i < LastSig.Returns.size(); i++) {
138 auto EVT = LastSig.Returns[i];
139 auto PVT = Stack[Stack.size() - LastSig.Returns.size() + i];
140 if (PVT != EVT)
141 return typeError(
142 ErrorLoc, StringRef("end got ") + WebAssembly::typeToString(PVT) +
143 ", expected " + WebAssembly::typeToString(EVT));
144 }
145 return false;
146 }
147
checkSig(SMLoc ErrorLoc,const wasm::WasmSignature & Sig)148 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
149 const wasm::WasmSignature& Sig) {
150 for (auto VT : llvm::reverse(Sig.Params))
151 if (popType(ErrorLoc, VT)) return true;
152 Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
153 return false;
154 }
155
getSymRef(SMLoc ErrorLoc,const MCInst & Inst,const MCSymbolRefExpr * & SymRef)156 bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
157 const MCSymbolRefExpr *&SymRef) {
158 auto Op = Inst.getOperand(0);
159 if (!Op.isExpr())
160 return typeError(ErrorLoc, StringRef("expected expression operand"));
161 SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
162 if (!SymRef)
163 return typeError(ErrorLoc, StringRef("expected symbol operand"));
164 return false;
165 }
166
getGlobal(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)167 bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
168 wasm::ValType &Type) {
169 const MCSymbolRefExpr *SymRef;
170 if (getSymRef(ErrorLoc, Inst, SymRef))
171 return true;
172 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
173 switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
174 case wasm::WASM_SYMBOL_TYPE_GLOBAL:
175 Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type);
176 break;
177 case wasm::WASM_SYMBOL_TYPE_FUNCTION:
178 case wasm::WASM_SYMBOL_TYPE_DATA:
179 switch (SymRef->getKind()) {
180 case MCSymbolRefExpr::VK_GOT:
181 case MCSymbolRefExpr::VK_WASM_GOT_TLS:
182 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
183 return false;
184 default:
185 break;
186 }
187 [[fallthrough]];
188 default:
189 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
190 " missing .globaltype");
191 }
192 return false;
193 }
194
getTable(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)195 bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
196 wasm::ValType &Type) {
197 const MCSymbolRefExpr *SymRef;
198 if (getSymRef(ErrorLoc, Inst, SymRef))
199 return true;
200 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
201 if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
202 wasm::WASM_SYMBOL_TYPE_TABLE)
203 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
204 " missing .tabletype");
205 Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType);
206 return false;
207 }
208
endOfFunction(SMLoc ErrorLoc)209 bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
210 // Check the return types.
211 for (auto RVT : llvm::reverse(ReturnTypes)) {
212 if (popType(ErrorLoc, RVT))
213 return true;
214 }
215 if (!Stack.empty()) {
216 return typeError(ErrorLoc, std::to_string(Stack.size()) +
217 " superfluous return values");
218 }
219 Unreachable = true;
220 return false;
221 }
222
typeCheck(SMLoc ErrorLoc,const MCInst & Inst,OperandVector & Operands)223 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
224 OperandVector &Operands) {
225 auto Opc = Inst.getOpcode();
226 auto Name = GetMnemonic(Opc);
227 dumpTypeStack("typechecking " + Name + ": ");
228 wasm::ValType Type;
229 if (Name == "local.get") {
230 if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
231 return true;
232 Stack.push_back(Type);
233 } else if (Name == "local.set") {
234 if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
235 return true;
236 if (popType(ErrorLoc, Type))
237 return true;
238 } else if (Name == "local.tee") {
239 if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
240 return true;
241 if (popType(ErrorLoc, Type))
242 return true;
243 Stack.push_back(Type);
244 } else if (Name == "global.get") {
245 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
246 return true;
247 Stack.push_back(Type);
248 } else if (Name == "global.set") {
249 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
250 return true;
251 if (popType(ErrorLoc, Type))
252 return true;
253 } else if (Name == "table.get") {
254 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
255 return true;
256 if (popType(ErrorLoc, wasm::ValType::I32))
257 return true;
258 Stack.push_back(Type);
259 } else if (Name == "table.set") {
260 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
261 return true;
262 if (popType(ErrorLoc, Type))
263 return true;
264 if (popType(ErrorLoc, wasm::ValType::I32))
265 return true;
266 } else if (Name == "table.fill") {
267 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
268 return true;
269 if (popType(ErrorLoc, wasm::ValType::I32))
270 return true;
271 if (popType(ErrorLoc, Type))
272 return true;
273 if (popType(ErrorLoc, wasm::ValType::I32))
274 return true;
275 } else if (Name == "drop") {
276 if (popType(ErrorLoc, {}))
277 return true;
278 } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
279 Name == "else" || Name == "end_try") {
280 if (checkEnd(ErrorLoc, Name == "else"))
281 return true;
282 if (Name == "end_block")
283 Unreachable = false;
284 } else if (Name == "return") {
285 if (endOfFunction(ErrorLoc))
286 return true;
287 } else if (Name == "call_indirect" || Name == "return_call_indirect") {
288 // Function value.
289 if (popType(ErrorLoc, wasm::ValType::I32)) return true;
290 if (checkSig(ErrorLoc, LastSig)) return true;
291 if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
292 return true;
293 } else if (Name == "call" || Name == "return_call") {
294 const MCSymbolRefExpr *SymRef;
295 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
296 return true;
297 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
298 auto Sig = WasmSym->getSignature();
299 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
300 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
301 WasmSym->getName() +
302 " missing .functype");
303 if (checkSig(ErrorLoc, *Sig)) return true;
304 if (Name == "return_call" && endOfFunction(ErrorLoc))
305 return true;
306 } else if (Name == "catch") {
307 const MCSymbolRefExpr *SymRef;
308 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
309 return true;
310 const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
311 const auto *Sig = WasmSym->getSignature();
312 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
313 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
314 WasmSym->getName() +
315 " missing .tagtype");
316 // catch instruction pushes values whose types are specified in the tag's
317 // "params" part
318 Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
319 } else if (Name == "unreachable") {
320 Unreachable = true;
321 } else if (Name == "ref.is_null") {
322 if (popRefType(ErrorLoc))
323 return true;
324 Stack.push_back(wasm::ValType::I32);
325 } else {
326 // The current instruction is a stack instruction which doesn't have
327 // explicit operands that indicate push/pop types, so we get those from
328 // the register version of the same instruction.
329 auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
330 assert(RegOpc != -1 && "Failed to get register version of MC instruction");
331 const auto &II = MII.get(RegOpc);
332 // First pop all the uses off the stack and check them.
333 for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
334 const auto &Op = II.operands()[I - 1];
335 if (Op.OperandType == MCOI::OPERAND_REGISTER) {
336 auto VT = WebAssembly::regClassToValType(Op.RegClass);
337 if (popType(ErrorLoc, VT))
338 return true;
339 }
340 }
341 // Now push all the defs onto the stack.
342 for (unsigned I = 0; I < II.getNumDefs(); I++) {
343 const auto &Op = II.operands()[I];
344 assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
345 auto VT = WebAssembly::regClassToValType(Op.RegClass);
346 Stack.push_back(VT);
347 }
348 }
349 return false;
350 }
351
352 } // end namespace llvm
353