1 //==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file is part of the WebAssembly Assembler.
11 ///
/// It contains the type checker that validates the MCInsts built from a parsed .s file.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "AsmParser/WebAssemblyAsmTypeCheck.h"
17 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18 #include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
19 #include "MCTargetDesc/WebAssemblyTargetStreamer.h"
20 #include "TargetInfo/WebAssemblyTargetInfo.h"
21 #include "WebAssembly.h"
22 #include "llvm/MC/MCContext.h"
23 #include "llvm/MC/MCExpr.h"
24 #include "llvm/MC/MCInst.h"
25 #include "llvm/MC/MCInstrInfo.h"
26 #include "llvm/MC/MCParser/MCParsedAsmOperand.h"
27 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
28 #include "llvm/MC/MCSectionWasm.h"
29 #include "llvm/MC/MCStreamer.h"
30 #include "llvm/MC/MCSubtargetInfo.h"
31 #include "llvm/MC/MCSymbol.h"
32 #include "llvm/MC/MCSymbolWasm.h"
33 #include "llvm/MC/TargetRegistry.h"
34 #include "llvm/Support/Compiler.h"
35 #include "llvm/Support/Endian.h"
36 #include "llvm/Support/SourceMgr.h"
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "wasm-asm-parser"
41 
42 extern StringRef GetMnemonic(unsigned Opc);
43 
44 namespace llvm {
45 
46 WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
47                                                  const MCInstrInfo &MII,
48                                                  bool is64)
49     : Parser(Parser), MII(MII), is64(is64) {}
50 
51 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
52   LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
53   ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
54   BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
55 }
56 
57 void WebAssemblyAsmTypeCheck::localDecl(
58     const SmallVectorImpl<wasm::ValType> &Locals) {
59   LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end());
60 }
61 
62 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
63   LLVM_DEBUG({
64     std::string s;
65     for (auto VT : Stack) {
66       s += WebAssembly::typeToString(VT);
67       s += " ";
68     }
69     dbgs() << Msg << s << '\n';
70   });
71 }
72 
73 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
74   // Once you get one type error in a function, it will likely trigger more
75   // which are mostly not helpful.
76   if (TypeErrorThisFunction)
77     return true;
78   // If we're currently in unreachable code, we suppress errors completely.
79   if (Unreachable)
80     return false;
81   TypeErrorThisFunction = true;
82   dumpTypeStack("current stack: ");
83   return Parser.Error(ErrorLoc, Msg);
84 }
85 
86 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
87                                       std::optional<wasm::ValType> EVT) {
88   if (Stack.empty()) {
89     return typeError(ErrorLoc,
90                      EVT ? StringRef("empty stack while popping ") +
91                                WebAssembly::typeToString(*EVT)
92                          : StringRef("empty stack while popping value"));
93   }
94   auto PVT = Stack.pop_back_val();
95   if (EVT && *EVT != PVT) {
96     return typeError(ErrorLoc,
97                      StringRef("popped ") + WebAssembly::typeToString(PVT) +
98                          ", expected " + WebAssembly::typeToString(*EVT));
99   }
100   return false;
101 }
102 
103 bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
104   if (Stack.empty()) {
105     return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
106   }
107   auto PVT = Stack.pop_back_val();
108   if (!WebAssembly::isRefType(PVT)) {
109     return typeError(ErrorLoc, StringRef("popped ") +
110                                    WebAssembly::typeToString(PVT) +
111                                    ", expected reftype");
112   }
113   return false;
114 }
115 
116 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
117                                        wasm::ValType &Type) {
118   auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
119   if (Local >= LocalTypes.size())
120     return typeError(ErrorLoc, StringRef("no local type specified for index ") +
121                                    std::to_string(Local));
122   Type = LocalTypes[Local];
123   return false;
124 }
125 
126 static std::optional<std::string>
127 checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop,
128               const SmallVectorImpl<wasm::ValType> &Got) {
129   for (size_t I = 0; I < ExpectedStackTop.size(); I++) {
130     auto EVT = ExpectedStackTop[I];
131     auto PVT = Got[Got.size() - ExpectedStackTop.size() + I];
132     if (PVT != EVT)
133       return std::string{"got "} + WebAssembly::typeToString(PVT) +
134              ", expected " + WebAssembly::typeToString(EVT);
135   }
136   return std::nullopt;
137 }
138 
139 bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
140   if (Level >= BrStack.size())
141     return typeError(ErrorLoc,
142                      StringRef("br: invalid depth ") + std::to_string(Level));
143   const SmallVector<wasm::ValType, 4> &Expected =
144       BrStack[BrStack.size() - Level - 1];
145   if (Expected.size() > Stack.size())
146     return typeError(ErrorLoc, "br: insufficient values on the type stack");
147   auto IsStackTopInvalid = checkStackTop(Expected, Stack);
148   if (IsStackTopInvalid)
149     return typeError(ErrorLoc, "br " + IsStackTopInvalid.value());
150   return false;
151 }
152 
153 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
154   if (!PopVals)
155     BrStack.pop_back();
156   if (LastSig.Returns.size() > Stack.size())
157     return typeError(ErrorLoc, "end: insufficient values on the type stack");
158 
159   if (PopVals) {
160     for (auto VT : llvm::reverse(LastSig.Returns)) {
161       if (popType(ErrorLoc, VT))
162         return true;
163     }
164     return false;
165   }
166 
167   auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack);
168   if (IsStackTopInvalid)
169     return typeError(ErrorLoc, "end " + IsStackTopInvalid.value());
170   return false;
171 }
172 
173 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
174                                        const wasm::WasmSignature &Sig) {
175   for (auto VT : llvm::reverse(Sig.Params))
176     if (popType(ErrorLoc, VT))
177       return true;
178   Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
179   return false;
180 }
181 
182 bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
183                                         const MCSymbolRefExpr *&SymRef) {
184   auto Op = Inst.getOperand(0);
185   if (!Op.isExpr())
186     return typeError(ErrorLoc, StringRef("expected expression operand"));
187   SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
188   if (!SymRef)
189     return typeError(ErrorLoc, StringRef("expected symbol operand"));
190   return false;
191 }
192 
193 bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
194                                         wasm::ValType &Type) {
195   const MCSymbolRefExpr *SymRef;
196   if (getSymRef(ErrorLoc, Inst, SymRef))
197     return true;
198   auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
199   switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
200   case wasm::WASM_SYMBOL_TYPE_GLOBAL:
201     Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type);
202     break;
203   case wasm::WASM_SYMBOL_TYPE_FUNCTION:
204   case wasm::WASM_SYMBOL_TYPE_DATA:
205     switch (SymRef->getKind()) {
206     case MCSymbolRefExpr::VK_GOT:
207     case MCSymbolRefExpr::VK_WASM_GOT_TLS:
208       Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
209       return false;
210     default:
211       break;
212     }
213     [[fallthrough]];
214   default:
215     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
216                                    " missing .globaltype");
217   }
218   return false;
219 }
220 
221 bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
222                                        wasm::ValType &Type) {
223   const MCSymbolRefExpr *SymRef;
224   if (getSymRef(ErrorLoc, Inst, SymRef))
225     return true;
226   auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
227   if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
228       wasm::WASM_SYMBOL_TYPE_TABLE)
229     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
230                                    " missing .tabletype");
231   Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType);
232   return false;
233 }
234 
235 bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
236   // Check the return types.
237   for (auto RVT : llvm::reverse(ReturnTypes)) {
238     if (popType(ErrorLoc, RVT))
239       return true;
240   }
241   if (!Stack.empty()) {
242     return typeError(ErrorLoc, std::to_string(Stack.size()) +
243                                    " superfluous return values");
244   }
245   Unreachable = true;
246   return false;
247 }
248 
249 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
250                                         OperandVector &Operands) {
251   auto Opc = Inst.getOpcode();
252   auto Name = GetMnemonic(Opc);
253   dumpTypeStack("typechecking " + Name + ": ");
254   wasm::ValType Type;
255   if (Name == "local.get") {
256     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
257       return true;
258     Stack.push_back(Type);
259   } else if (Name == "local.set") {
260     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
261       return true;
262     if (popType(ErrorLoc, Type))
263       return true;
264   } else if (Name == "local.tee") {
265     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
266       return true;
267     if (popType(ErrorLoc, Type))
268       return true;
269     Stack.push_back(Type);
270   } else if (Name == "global.get") {
271     if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
272       return true;
273     Stack.push_back(Type);
274   } else if (Name == "global.set") {
275     if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
276       return true;
277     if (popType(ErrorLoc, Type))
278       return true;
279   } else if (Name == "table.get") {
280     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
281       return true;
282     if (popType(ErrorLoc, wasm::ValType::I32))
283       return true;
284     Stack.push_back(Type);
285   } else if (Name == "table.set") {
286     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
287       return true;
288     if (popType(ErrorLoc, Type))
289       return true;
290     if (popType(ErrorLoc, wasm::ValType::I32))
291       return true;
292   } else if (Name == "table.fill") {
293     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
294       return true;
295     if (popType(ErrorLoc, wasm::ValType::I32))
296       return true;
297     if (popType(ErrorLoc, Type))
298       return true;
299     if (popType(ErrorLoc, wasm::ValType::I32))
300       return true;
301   } else if (Name == "memory.fill") {
302     Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
303     if (popType(ErrorLoc, Type))
304       return true;
305     if (popType(ErrorLoc, wasm::ValType::I32))
306       return true;
307     if (popType(ErrorLoc, Type))
308       return true;
309   } else if (Name == "memory.copy") {
310     Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
311     if (popType(ErrorLoc, Type))
312       return true;
313     if (popType(ErrorLoc, Type))
314       return true;
315     if (popType(ErrorLoc, Type))
316       return true;
317   } else if (Name == "memory.init") {
318     Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
319     if (popType(ErrorLoc, wasm::ValType::I32))
320       return true;
321     if (popType(ErrorLoc, wasm::ValType::I32))
322       return true;
323     if (popType(ErrorLoc, Type))
324       return true;
325   } else if (Name == "drop") {
326     if (popType(ErrorLoc, {}))
327       return true;
328   } else if (Name == "try" || Name == "block" || Name == "loop" ||
329              Name == "if") {
330     if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
331       return true;
332     if (Name == "loop")
333       BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
334     else
335       BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
336   } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
337              Name == "else" || Name == "end_try" || Name == "catch" ||
338              Name == "catch_all" || Name == "delegate") {
339     if (checkEnd(ErrorLoc,
340                  Name == "else" || Name == "catch" || Name == "catch_all"))
341       return true;
342     Unreachable = false;
343     if (Name == "catch") {
344       const MCSymbolRefExpr *SymRef;
345       if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
346         return true;
347       const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
348       const auto *Sig = WasmSym->getSignature();
349       if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
350         return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
351                                                          WasmSym->getName() +
352                                                          " missing .tagtype");
353       // catch instruction pushes values whose types are specified in the tag's
354       // "params" part
355       Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
356     }
357   } else if (Name == "br") {
358     const MCOperand &Operand = Inst.getOperand(0);
359     if (!Operand.isImm())
360       return false;
361     if (checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm())))
362       return true;
363   } else if (Name == "return") {
364     if (endOfFunction(ErrorLoc))
365       return true;
366   } else if (Name == "call_indirect" || Name == "return_call_indirect") {
367     // Function value.
368     if (popType(ErrorLoc, wasm::ValType::I32))
369       return true;
370     if (checkSig(ErrorLoc, LastSig))
371       return true;
372     if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
373       return true;
374   } else if (Name == "call" || Name == "return_call") {
375     const MCSymbolRefExpr *SymRef;
376     if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
377       return true;
378     auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
379     auto Sig = WasmSym->getSignature();
380     if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
381       return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
382                                                        WasmSym->getName() +
383                                                        " missing .functype");
384     if (checkSig(ErrorLoc, *Sig))
385       return true;
386     if (Name == "return_call" && endOfFunction(ErrorLoc))
387       return true;
388   } else if (Name == "unreachable") {
389     Unreachable = true;
390   } else if (Name == "ref.is_null") {
391     if (popRefType(ErrorLoc))
392       return true;
393     Stack.push_back(wasm::ValType::I32);
394   } else {
395     // The current instruction is a stack instruction which doesn't have
396     // explicit operands that indicate push/pop types, so we get those from
397     // the register version of the same instruction.
398     auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
399     assert(RegOpc != -1 && "Failed to get register version of MC instruction");
400     const auto &II = MII.get(RegOpc);
401     // First pop all the uses off the stack and check them.
402     for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
403       const auto &Op = II.operands()[I - 1];
404       if (Op.OperandType == MCOI::OPERAND_REGISTER) {
405         auto VT = WebAssembly::regClassToValType(Op.RegClass);
406         if (popType(ErrorLoc, VT))
407           return true;
408       }
409     }
410     // Now push all the defs onto the stack.
411     for (unsigned I = 0; I < II.getNumDefs(); I++) {
412       const auto &Op = II.operands()[I];
413       assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
414       auto VT = WebAssembly::regClassToValType(Op.RegClass);
415       Stack.push_back(VT);
416     }
417   }
418   return false;
419 }
420 
421 } // end namespace llvm
422