1 //==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file is part of the WebAssembly Assembler.
11 ///
/// It contains the type checker that validates the value-type stack effects
/// of parsed instructions.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "AsmParser/WebAssemblyAsmTypeCheck.h"
17 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18 #include "MCTargetDesc/WebAssemblyTargetStreamer.h"
19 #include "TargetInfo/WebAssemblyTargetInfo.h"
20 #include "Utils/WebAssemblyTypeUtilities.h"
21 #include "Utils/WebAssemblyUtilities.h"
22 #include "WebAssembly.h"
23 #include "llvm/MC/MCContext.h"
24 #include "llvm/MC/MCExpr.h"
25 #include "llvm/MC/MCInst.h"
26 #include "llvm/MC/MCInstrInfo.h"
27 #include "llvm/MC/MCParser/MCParsedAsmOperand.h"
28 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
29 #include "llvm/MC/MCSectionWasm.h"
30 #include "llvm/MC/MCStreamer.h"
31 #include "llvm/MC/MCSubtargetInfo.h"
32 #include "llvm/MC/MCSymbol.h"
33 #include "llvm/MC/MCSymbolWasm.h"
34 #include "llvm/MC/TargetRegistry.h"
35 #include "llvm/Support/Compiler.h"
36 #include "llvm/Support/Endian.h"
37 #include "llvm/Support/SourceMgr.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "wasm-asm-parser"
42 
43 extern StringRef GetMnemonic(unsigned Opc);
44 
45 namespace llvm {
46 
47 WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
48                                                  const MCInstrInfo &MII, bool is64)
49     : Parser(Parser), MII(MII), is64(is64) {
50 }
51 
52 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
53   LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
54   ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
55 }
56 
57 void WebAssemblyAsmTypeCheck::localDecl(const SmallVector<wasm::ValType, 4> &Locals) {
58   LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end());
59 }
60 
61 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
62   LLVM_DEBUG({
63     std::string s;
64     for (auto VT : Stack) {
65       s += WebAssembly::typeToString(VT);
66       s += " ";
67     }
68     dbgs() << Msg << s << '\n';
69   });
70 }
71 
72 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
73   // Once you get one type error in a function, it will likely trigger more
74   // which are mostly not helpful.
75   if (TypeErrorThisFunction)
76     return true;
77   // If we're currently in unreachable code, we suppress errors completely.
78   if (Unreachable)
79     return false;
80   TypeErrorThisFunction = true;
81   dumpTypeStack("current stack: ");
82   return Parser.Error(ErrorLoc, Msg);
83 }
84 
85 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
86                                       std::optional<wasm::ValType> EVT) {
87   if (Stack.empty()) {
88     return typeError(ErrorLoc,
89                      EVT ? StringRef("empty stack while popping ") +
90                                WebAssembly::typeToString(*EVT)
91                          : StringRef("empty stack while popping value"));
92   }
93   auto PVT = Stack.pop_back_val();
94   if (EVT && *EVT != PVT) {
95     return typeError(ErrorLoc,
96                      StringRef("popped ") + WebAssembly::typeToString(PVT) +
97                          ", expected " + WebAssembly::typeToString(*EVT));
98   }
99   return false;
100 }
101 
102 bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
103   if (Stack.empty()) {
104     return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
105   }
106   auto PVT = Stack.pop_back_val();
107   if (!WebAssembly::isRefType(PVT)) {
108     return typeError(ErrorLoc, StringRef("popped ") +
109                                    WebAssembly::typeToString(PVT) +
110                                    ", expected reftype");
111   }
112   return false;
113 }
114 
115 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
116                                        wasm::ValType &Type) {
117   auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
118   if (Local >= LocalTypes.size())
119     return typeError(ErrorLoc, StringRef("no local type specified for index ") +
120                           std::to_string(Local));
121   Type = LocalTypes[Local];
122   return false;
123 }
124 
125 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
126   if (LastSig.Returns.size() > Stack.size())
127     return typeError(ErrorLoc, "end: insufficient values on the type stack");
128 
129   if (PopVals) {
130     for (auto VT : llvm::reverse(LastSig.Returns)) {
131       if (popType(ErrorLoc, VT))
132         return true;
133     }
134     return false;
135   }
136 
137   for (size_t i = 0; i < LastSig.Returns.size(); i++) {
138     auto EVT = LastSig.Returns[i];
139     auto PVT = Stack[Stack.size() - LastSig.Returns.size() + i];
140     if (PVT != EVT)
141       return typeError(
142           ErrorLoc, StringRef("end got ") + WebAssembly::typeToString(PVT) +
143                         ", expected " + WebAssembly::typeToString(EVT));
144   }
145   return false;
146 }
147 
148 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
149                                        const wasm::WasmSignature& Sig) {
150   for (auto VT : llvm::reverse(Sig.Params))
151     if (popType(ErrorLoc, VT)) return true;
152   Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
153   return false;
154 }
155 
156 bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
157                                         const MCSymbolRefExpr *&SymRef) {
158   auto Op = Inst.getOperand(0);
159   if (!Op.isExpr())
160     return typeError(ErrorLoc, StringRef("expected expression operand"));
161   SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
162   if (!SymRef)
163     return typeError(ErrorLoc, StringRef("expected symbol operand"));
164   return false;
165 }
166 
167 bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
168                                         wasm::ValType &Type) {
169   const MCSymbolRefExpr *SymRef;
170   if (getSymRef(ErrorLoc, Inst, SymRef))
171     return true;
172   auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
173   switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
174   case wasm::WASM_SYMBOL_TYPE_GLOBAL:
175     Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type);
176     break;
177   case wasm::WASM_SYMBOL_TYPE_FUNCTION:
178   case wasm::WASM_SYMBOL_TYPE_DATA:
179     switch (SymRef->getKind()) {
180     case MCSymbolRefExpr::VK_GOT:
181     case MCSymbolRefExpr::VK_WASM_GOT_TLS:
182       Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
183       return false;
184     default:
185       break;
186     }
187     [[fallthrough]];
188   default:
189     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
190                                     " missing .globaltype");
191   }
192   return false;
193 }
194 
195 bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
196                                        wasm::ValType &Type) {
197   const MCSymbolRefExpr *SymRef;
198   if (getSymRef(ErrorLoc, Inst, SymRef))
199     return true;
200   auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
201   if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
202       wasm::WASM_SYMBOL_TYPE_TABLE)
203     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
204                                    " missing .tabletype");
205   Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType);
206   return false;
207 }
208 
209 bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
210   // Check the return types.
211   for (auto RVT : llvm::reverse(ReturnTypes)) {
212     if (popType(ErrorLoc, RVT))
213       return true;
214   }
215   if (!Stack.empty()) {
216     return typeError(ErrorLoc, std::to_string(Stack.size()) +
217                                    " superfluous return values");
218   }
219   Unreachable = true;
220   return false;
221 }
222 
223 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
224                                         OperandVector &Operands) {
225   auto Opc = Inst.getOpcode();
226   auto Name = GetMnemonic(Opc);
227   dumpTypeStack("typechecking " + Name + ": ");
228   wasm::ValType Type;
229   if (Name == "local.get") {
230     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
231       return true;
232     Stack.push_back(Type);
233   } else if (Name == "local.set") {
234     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
235       return true;
236     if (popType(ErrorLoc, Type))
237       return true;
238   } else if (Name == "local.tee") {
239     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
240       return true;
241     if (popType(ErrorLoc, Type))
242       return true;
243     Stack.push_back(Type);
244   } else if (Name == "global.get") {
245     if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
246       return true;
247     Stack.push_back(Type);
248   } else if (Name == "global.set") {
249     if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
250       return true;
251     if (popType(ErrorLoc, Type))
252       return true;
253   } else if (Name == "table.get") {
254     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
255       return true;
256     if (popType(ErrorLoc, wasm::ValType::I32))
257       return true;
258     Stack.push_back(Type);
259   } else if (Name == "table.set") {
260     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
261       return true;
262     if (popType(ErrorLoc, Type))
263       return true;
264     if (popType(ErrorLoc, wasm::ValType::I32))
265       return true;
266   } else if (Name == "table.fill") {
267     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
268       return true;
269     if (popType(ErrorLoc, wasm::ValType::I32))
270       return true;
271     if (popType(ErrorLoc, Type))
272       return true;
273     if (popType(ErrorLoc, wasm::ValType::I32))
274       return true;
275   } else if (Name == "drop") {
276     if (popType(ErrorLoc, {}))
277       return true;
278   } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
279              Name == "else" || Name == "end_try") {
280     if (checkEnd(ErrorLoc, Name == "else"))
281       return true;
282     if (Name == "end_block")
283       Unreachable = false;
284   } else if (Name == "return") {
285     if (endOfFunction(ErrorLoc))
286       return true;
287   } else if (Name == "call_indirect" || Name == "return_call_indirect") {
288     // Function value.
289     if (popType(ErrorLoc, wasm::ValType::I32)) return true;
290     if (checkSig(ErrorLoc, LastSig)) return true;
291     if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
292       return true;
293   } else if (Name == "call" || Name == "return_call") {
294     const MCSymbolRefExpr *SymRef;
295     if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
296       return true;
297     auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
298     auto Sig = WasmSym->getSignature();
299     if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
300       return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
301                                                        WasmSym->getName() +
302                                                        " missing .functype");
303     if (checkSig(ErrorLoc, *Sig)) return true;
304     if (Name == "return_call" && endOfFunction(ErrorLoc))
305       return true;
306   } else if (Name == "catch") {
307     const MCSymbolRefExpr *SymRef;
308     if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
309       return true;
310     const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
311     const auto *Sig = WasmSym->getSignature();
312     if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
313       return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
314                                                        WasmSym->getName() +
315                                                        " missing .tagtype");
316     // catch instruction pushes values whose types are specified in the tag's
317     // "params" part
318     Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
319   } else if (Name == "unreachable") {
320     Unreachable = true;
321   } else if (Name == "ref.is_null") {
322     if (popRefType(ErrorLoc))
323       return true;
324     Stack.push_back(wasm::ValType::I32);
325   } else {
326     // The current instruction is a stack instruction which doesn't have
327     // explicit operands that indicate push/pop types, so we get those from
328     // the register version of the same instruction.
329     auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
330     assert(RegOpc != -1 && "Failed to get register version of MC instruction");
331     const auto &II = MII.get(RegOpc);
332     // First pop all the uses off the stack and check them.
333     for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
334       const auto &Op = II.operands()[I - 1];
335       if (Op.OperandType == MCOI::OPERAND_REGISTER) {
336         auto VT = WebAssembly::regClassToValType(Op.RegClass);
337         if (popType(ErrorLoc, VT))
338           return true;
339       }
340     }
341     // Now push all the defs onto the stack.
342     for (unsigned I = 0; I < II.getNumDefs(); I++) {
343       const auto &Op = II.operands()[I];
344       assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
345       auto VT = WebAssembly::regClassToValType(Op.RegClass);
346       Stack.push_back(VT);
347     }
348   }
349   return false;
350 }
351 
352 } // end namespace llvm
353