1 //==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file is part of the WebAssembly Assembler. 11 /// 12 /// It contains code to translate a parsed .s file into MCInsts. 13 /// 14 //===----------------------------------------------------------------------===// 15 16 #include "AsmParser/WebAssemblyAsmTypeCheck.h" 17 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" 18 #include "MCTargetDesc/WebAssemblyMCTypeUtilities.h" 19 #include "MCTargetDesc/WebAssemblyTargetStreamer.h" 20 #include "TargetInfo/WebAssemblyTargetInfo.h" 21 #include "WebAssembly.h" 22 #include "llvm/MC/MCContext.h" 23 #include "llvm/MC/MCExpr.h" 24 #include "llvm/MC/MCInst.h" 25 #include "llvm/MC/MCInstrInfo.h" 26 #include "llvm/MC/MCParser/MCParsedAsmOperand.h" 27 #include "llvm/MC/MCParser/MCTargetAsmParser.h" 28 #include "llvm/MC/MCSectionWasm.h" 29 #include "llvm/MC/MCStreamer.h" 30 #include "llvm/MC/MCSubtargetInfo.h" 31 #include "llvm/MC/MCSymbol.h" 32 #include "llvm/MC/MCSymbolWasm.h" 33 #include "llvm/MC/TargetRegistry.h" 34 #include "llvm/Support/Compiler.h" 35 #include "llvm/Support/Endian.h" 36 #include "llvm/Support/SourceMgr.h" 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "wasm-asm-parser" 41 42 extern StringRef GetMnemonic(unsigned Opc); 43 44 namespace llvm { 45 46 WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser, 47 const MCInstrInfo &MII, 48 bool is64) 49 : Parser(Parser), MII(MII), is64(is64) {} 50 51 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) { 52 LocalTypes.assign(Sig.Params.begin(), Sig.Params.end()); 53 ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end()); 54 BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end()); 55 } 56 57 void WebAssemblyAsmTypeCheck::localDecl( 58 const SmallVectorImpl<wasm::ValType> &Locals) { 59 LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end()); 60 } 61 62 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) { 63 LLVM_DEBUG({ 64 std::string s; 65 for (auto VT : Stack) { 66 s += WebAssembly::typeToString(VT); 67 s += " "; 68 } 69 dbgs() << Msg << s << '\n'; 70 }); 71 } 72 73 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) { 74 // Once you get one type error in a function, it will likely trigger more 75 // which are mostly not helpful. 76 if (TypeErrorThisFunction) 77 return true; 78 // If we're currently in unreachable code, we suppress errors completely. 79 if (Unreachable) 80 return false; 81 TypeErrorThisFunction = true; 82 dumpTypeStack("current stack: "); 83 return Parser.Error(ErrorLoc, Msg); 84 } 85 86 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc, 87 std::optional<wasm::ValType> EVT) { 88 if (Stack.empty()) { 89 return typeError(ErrorLoc, 90 EVT ? StringRef("empty stack while popping ") + 91 WebAssembly::typeToString(*EVT) 92 : StringRef("empty stack while popping value")); 93 } 94 auto PVT = Stack.pop_back_val(); 95 if (EVT && *EVT != PVT) { 96 return typeError(ErrorLoc, 97 StringRef("popped ") + WebAssembly::typeToString(PVT) + 98 ", expected " + WebAssembly::typeToString(*EVT)); 99 } 100 return false; 101 } 102 103 bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) { 104 if (Stack.empty()) { 105 return typeError(ErrorLoc, StringRef("empty stack while popping reftype")); 106 } 107 auto PVT = Stack.pop_back_val(); 108 if (!WebAssembly::isRefType(PVT)) { 109 return typeError(ErrorLoc, StringRef("popped ") + 110 WebAssembly::typeToString(PVT) + 111 ", expected reftype"); 112 } 113 return false; 114 } 115 116 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst, 117 wasm::ValType &Type) { 118 auto Local = static_cast<size_t>(Inst.getOperand(0).getImm()); 119 if (Local >= LocalTypes.size()) 120 return typeError(ErrorLoc, StringRef("no local type specified for index ") + 121 std::to_string(Local)); 122 Type = LocalTypes[Local]; 123 return false; 124 } 125 126 static std::optional<std::string> 127 checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop, 128 const SmallVectorImpl<wasm::ValType> &Got) { 129 for (size_t I = 0; I < ExpectedStackTop.size(); I++) { 130 auto EVT = ExpectedStackTop[I]; 131 auto PVT = Got[Got.size() - ExpectedStackTop.size() + I]; 132 if (PVT != EVT) 133 return std::string{"got "} + WebAssembly::typeToString(PVT) + 134 ", expected " + WebAssembly::typeToString(EVT); 135 } 136 return std::nullopt; 137 } 138 139 bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) { 140 if (Level >= BrStack.size()) 141 return typeError(ErrorLoc, 142 StringRef("br: invalid depth ") + std::to_string(Level)); 143 const SmallVector<wasm::ValType, 4> &Expected = 144 BrStack[BrStack.size() - Level - 1]; 145 if (Expected.size() > Stack.size()) 146 return typeError(ErrorLoc, "br: insufficient values on the type stack"); 147 auto IsStackTopInvalid = checkStackTop(Expected, Stack); 148 if (IsStackTopInvalid) 149 return typeError(ErrorLoc, "br " + IsStackTopInvalid.value()); 150 return false; 151 } 152 153 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) { 154 if (!PopVals) 155 BrStack.pop_back(); 156 if (LastSig.Returns.size() > Stack.size()) 157 return typeError(ErrorLoc, "end: insufficient values on the type stack"); 158 159 if (PopVals) { 160 for (auto VT : llvm::reverse(LastSig.Returns)) { 161 if (popType(ErrorLoc, VT)) 162 return true; 163 } 164 return false; 165 } 166 167 auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack); 168 if (IsStackTopInvalid) 169 return typeError(ErrorLoc, "end " + IsStackTopInvalid.value()); 170 return false; 171 } 172 173 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc, 174 const wasm::WasmSignature &Sig) { 175 for (auto VT : llvm::reverse(Sig.Params)) 176 if (popType(ErrorLoc, VT)) 177 return true; 178 Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end()); 179 return false; 180 } 181 182 bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst, 183 const MCSymbolRefExpr *&SymRef) { 184 auto Op = Inst.getOperand(0); 185 if (!Op.isExpr()) 186 return typeError(ErrorLoc, StringRef("expected expression operand")); 187 SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr()); 188 if (!SymRef) 189 return typeError(ErrorLoc, StringRef("expected symbol operand")); 190 return false; 191 } 192 193 bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst, 194 wasm::ValType &Type) { 195 const MCSymbolRefExpr *SymRef; 196 if (getSymRef(ErrorLoc, Inst, SymRef)) 197 return true; 198 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 199 switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) { 200 case wasm::WASM_SYMBOL_TYPE_GLOBAL: 201 Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type); 202 break; 203 case wasm::WASM_SYMBOL_TYPE_FUNCTION: 204 case wasm::WASM_SYMBOL_TYPE_DATA: 205 switch (SymRef->getKind()) { 206 case MCSymbolRefExpr::VK_GOT: 207 case MCSymbolRefExpr::VK_WASM_GOT_TLS: 208 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 209 return false; 210 default: 211 break; 212 } 213 [[fallthrough]]; 214 default: 215 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() + 216 " missing .globaltype"); 217 } 218 return false; 219 } 220 221 bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst, 222 wasm::ValType &Type) { 223 const MCSymbolRefExpr *SymRef; 224 if (getSymRef(ErrorLoc, Inst, SymRef)) 225 return true; 226 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 227 if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) != 228 wasm::WASM_SYMBOL_TYPE_TABLE) 229 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() + 230 " missing .tabletype"); 231 Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType); 232 return false; 233 } 234 235 bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) { 236 // Check the return types. 237 for (auto RVT : llvm::reverse(ReturnTypes)) { 238 if (popType(ErrorLoc, RVT)) 239 return true; 240 } 241 if (!Stack.empty()) { 242 return typeError(ErrorLoc, std::to_string(Stack.size()) + 243 " superfluous return values"); 244 } 245 Unreachable = true; 246 return false; 247 } 248 249 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst, 250 OperandVector &Operands) { 251 auto Opc = Inst.getOpcode(); 252 auto Name = GetMnemonic(Opc); 253 dumpTypeStack("typechecking " + Name + ": "); 254 wasm::ValType Type; 255 if (Name == "local.get") { 256 if (getLocal(Operands[1]->getStartLoc(), Inst, Type)) 257 return true; 258 Stack.push_back(Type); 259 } else if (Name == "local.set") { 260 if (getLocal(Operands[1]->getStartLoc(), Inst, Type)) 261 return true; 262 if (popType(ErrorLoc, Type)) 263 return true; 264 } else if (Name == "local.tee") { 265 if (getLocal(Operands[1]->getStartLoc(), Inst, Type)) 266 return true; 267 if (popType(ErrorLoc, Type)) 268 return true; 269 Stack.push_back(Type); 270 } else if (Name == "global.get") { 271 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type)) 272 return true; 273 Stack.push_back(Type); 274 } else if (Name == "global.set") { 275 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type)) 276 return true; 277 if (popType(ErrorLoc, Type)) 278 return true; 279 } else if (Name == "table.get") { 280 if (getTable(Operands[1]->getStartLoc(), Inst, Type)) 281 return true; 282 if (popType(ErrorLoc, wasm::ValType::I32)) 283 return true; 284 Stack.push_back(Type); 285 } else if (Name == "table.set") { 286 if (getTable(Operands[1]->getStartLoc(), Inst, Type)) 287 return true; 288 if (popType(ErrorLoc, Type)) 289 return true; 290 if (popType(ErrorLoc, wasm::ValType::I32)) 291 return true; 292 } else if (Name == "table.fill") { 293 if (getTable(Operands[1]->getStartLoc(), Inst, Type)) 294 return true; 295 if (popType(ErrorLoc, wasm::ValType::I32)) 296 return true; 297 if (popType(ErrorLoc, Type)) 298 return true; 299 if (popType(ErrorLoc, wasm::ValType::I32)) 300 return true; 301 } else if (Name == "memory.fill") { 302 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 303 if (popType(ErrorLoc, Type)) 304 return true; 305 if (popType(ErrorLoc, wasm::ValType::I32)) 306 return true; 307 if (popType(ErrorLoc, Type)) 308 return true; 309 } else if (Name == "memory.copy") { 310 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 311 if (popType(ErrorLoc, Type)) 312 return true; 313 if (popType(ErrorLoc, Type)) 314 return true; 315 if (popType(ErrorLoc, Type)) 316 return true; 317 } else if (Name == "memory.init") { 318 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 319 if (popType(ErrorLoc, wasm::ValType::I32)) 320 return true; 321 if (popType(ErrorLoc, wasm::ValType::I32)) 322 return true; 323 if (popType(ErrorLoc, Type)) 324 return true; 325 } else if (Name == "drop") { 326 if (popType(ErrorLoc, {})) 327 return true; 328 } else if (Name == "try" || Name == "block" || Name == "loop" || 329 Name == "if") { 330 if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32)) 331 return true; 332 if (Name == "loop") 333 BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end()); 334 else 335 BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end()); 336 } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" || 337 Name == "else" || Name == "end_try" || Name == "catch" || 338 Name == "catch_all" || Name == "delegate") { 339 if (checkEnd(ErrorLoc, 340 Name == "else" || Name == "catch" || Name == "catch_all")) 341 return true; 342 Unreachable = false; 343 if (Name == "catch") { 344 const MCSymbolRefExpr *SymRef; 345 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef)) 346 return true; 347 const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 348 const auto *Sig = WasmSym->getSignature(); 349 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG) 350 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") + 351 WasmSym->getName() + 352 " missing .tagtype"); 353 // catch instruction pushes values whose types are specified in the tag's 354 // "params" part 355 Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end()); 356 } 357 } else if (Name == "br") { 358 const MCOperand &Operand = Inst.getOperand(0); 359 if (!Operand.isImm()) 360 return false; 361 if (checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm()))) 362 return true; 363 } else if (Name == "return") { 364 if (endOfFunction(ErrorLoc)) 365 return true; 366 } else if (Name == "call_indirect" || Name == "return_call_indirect") { 367 // Function value. 368 if (popType(ErrorLoc, wasm::ValType::I32)) 369 return true; 370 if (checkSig(ErrorLoc, LastSig)) 371 return true; 372 if (Name == "return_call_indirect" && endOfFunction(ErrorLoc)) 373 return true; 374 } else if (Name == "call" || Name == "return_call") { 375 const MCSymbolRefExpr *SymRef; 376 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef)) 377 return true; 378 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 379 auto Sig = WasmSym->getSignature(); 380 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION) 381 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") + 382 WasmSym->getName() + 383 " missing .functype"); 384 if (checkSig(ErrorLoc, *Sig)) 385 return true; 386 if (Name == "return_call" && endOfFunction(ErrorLoc)) 387 return true; 388 } else if (Name == "unreachable") { 389 Unreachable = true; 390 } else if (Name == "ref.is_null") { 391 if (popRefType(ErrorLoc)) 392 return true; 393 Stack.push_back(wasm::ValType::I32); 394 } else { 395 // The current instruction is a stack instruction which doesn't have 396 // explicit operands that indicate push/pop types, so we get those from 397 // the register version of the same instruction. 398 auto RegOpc = WebAssembly::getRegisterOpcode(Opc); 399 assert(RegOpc != -1 && "Failed to get register version of MC instruction"); 400 const auto &II = MII.get(RegOpc); 401 // First pop all the uses off the stack and check them. 402 for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) { 403 const auto &Op = II.operands()[I - 1]; 404 if (Op.OperandType == MCOI::OPERAND_REGISTER) { 405 auto VT = WebAssembly::regClassToValType(Op.RegClass); 406 if (popType(ErrorLoc, VT)) 407 return true; 408 } 409 } 410 // Now push all the defs onto the stack. 411 for (unsigned I = 0; I < II.getNumDefs(); I++) { 412 const auto &Op = II.operands()[I]; 413 assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected"); 414 auto VT = WebAssembly::regClassToValType(Op.RegClass); 415 Stack.push_back(VT); 416 } 417 } 418 return false; 419 } 420 421 } // end namespace llvm 422