1 #include "../include/KaleidoscopeJIT.h" 2 #include "llvm/ADT/APFloat.h" 3 #include "llvm/ADT/STLExtras.h" 4 #include "llvm/IR/BasicBlock.h" 5 #include "llvm/IR/Constants.h" 6 #include "llvm/IR/DerivedTypes.h" 7 #include "llvm/IR/Function.h" 8 #include "llvm/IR/IRBuilder.h" 9 #include "llvm/IR/LLVMContext.h" 10 #include "llvm/IR/LegacyPassManager.h" 11 #include "llvm/IR/Module.h" 12 #include "llvm/IR/Type.h" 13 #include "llvm/IR/Verifier.h" 14 #include "llvm/Support/TargetSelect.h" 15 #include "llvm/Target/TargetMachine.h" 16 #include "llvm/Transforms/InstCombine/InstCombine.h" 17 #include "llvm/Transforms/Scalar.h" 18 #include "llvm/Transforms/Scalar/GVN.h" 19 #include <algorithm> 20 #include <cassert> 21 #include <cctype> 22 #include <cstdint> 23 #include <cstdio> 24 #include <cstdlib> 25 #include <map> 26 #include <memory> 27 #include <string> 28 #include <vector> 29 30 using namespace llvm; 31 using namespace llvm::orc; 32 33 //===----------------------------------------------------------------------===// 34 // Lexer 35 //===----------------------------------------------------------------------===// 36 37 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one 38 // of these for known things. 39 enum Token { 40 tok_eof = -1, 41 42 // commands 43 tok_def = -2, 44 tok_extern = -3, 45 46 // primary 47 tok_identifier = -4, 48 tok_number = -5 49 }; 50 51 static std::string IdentifierStr; // Filled in if tok_identifier 52 static double NumVal; // Filled in if tok_number 53 54 /// gettok - Return the next token from standard input. 55 static int gettok() { 56 static int LastChar = ' '; 57 58 // Skip any whitespace. 59 while (isspace(LastChar)) 60 LastChar = getchar(); 61 62 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 63 IdentifierStr = LastChar; 64 while (isalnum((LastChar = getchar()))) 65 IdentifierStr += LastChar; 66 67 if (IdentifierStr == "def") 68 return tok_def; 69 if (IdentifierStr == "extern") 70 return tok_extern; 71 return tok_identifier; 72 } 73 74 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 75 std::string NumStr; 76 do { 77 NumStr += LastChar; 78 LastChar = getchar(); 79 } while (isdigit(LastChar) || LastChar == '.'); 80 81 NumVal = strtod(NumStr.c_str(), nullptr); 82 return tok_number; 83 } 84 85 if (LastChar == '#') { 86 // Comment until end of line. 87 do 88 LastChar = getchar(); 89 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 90 91 if (LastChar != EOF) 92 return gettok(); 93 } 94 95 // Check for end of file. Don't eat the EOF. 96 if (LastChar == EOF) 97 return tok_eof; 98 99 // Otherwise, just return the character as its ascii value. 100 int ThisChar = LastChar; 101 LastChar = getchar(); 102 return ThisChar; 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Abstract Syntax Tree (aka Parse Tree) 107 //===----------------------------------------------------------------------===// 108 109 namespace { 110 111 /// ExprAST - Base class for all expression nodes. 112 class ExprAST { 113 public: 114 virtual ~ExprAST() = default; 115 116 virtual Value *codegen() = 0; 117 }; 118 119 /// NumberExprAST - Expression class for numeric literals like "1.0". 120 class NumberExprAST : public ExprAST { 121 double Val; 122 123 public: 124 NumberExprAST(double Val) : Val(Val) {} 125 126 Value *codegen() override; 127 }; 128 129 /// VariableExprAST - Expression class for referencing a variable, like "a". 130 class VariableExprAST : public ExprAST { 131 std::string Name; 132 133 public: 134 VariableExprAST(const std::string &Name) : Name(Name) {} 135 136 Value *codegen() override; 137 }; 138 139 /// BinaryExprAST - Expression class for a binary operator. 140 class BinaryExprAST : public ExprAST { 141 char Op; 142 std::unique_ptr<ExprAST> LHS, RHS; 143 144 public: 145 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS, 146 std::unique_ptr<ExprAST> RHS) 147 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} 148 149 Value *codegen() override; 150 }; 151 152 /// CallExprAST - Expression class for function calls. 153 class CallExprAST : public ExprAST { 154 std::string Callee; 155 std::vector<std::unique_ptr<ExprAST>> Args; 156 157 public: 158 CallExprAST(const std::string &Callee, 159 std::vector<std::unique_ptr<ExprAST>> Args) 160 : Callee(Callee), Args(std::move(Args)) {} 161 162 Value *codegen() override; 163 }; 164 165 /// PrototypeAST - This class represents the "prototype" for a function, 166 /// which captures its name, and its argument names (thus implicitly the number 167 /// of arguments the function takes). 168 class PrototypeAST { 169 std::string Name; 170 std::vector<std::string> Args; 171 172 public: 173 PrototypeAST(const std::string &Name, std::vector<std::string> Args) 174 : Name(Name), Args(std::move(Args)) {} 175 176 Function *codegen(); 177 const std::string &getName() const { return Name; } 178 }; 179 180 /// FunctionAST - This class represents a function definition itself. 181 class FunctionAST { 182 std::unique_ptr<PrototypeAST> Proto; 183 std::unique_ptr<ExprAST> Body; 184 185 public: 186 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 187 std::unique_ptr<ExprAST> Body) 188 : Proto(std::move(Proto)), Body(std::move(Body)) {} 189 190 Function *codegen(); 191 }; 192 193 } // end anonymous namespace 194 195 //===----------------------------------------------------------------------===// 196 // Parser 197 //===----------------------------------------------------------------------===// 198 199 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 200 /// token the parser is looking at. getNextToken reads another token from the 201 /// lexer and updates CurTok with its results. 202 static int CurTok; 203 static int getNextToken() { return CurTok = gettok(); } 204 205 /// BinopPrecedence - This holds the precedence for each binary operator that is 206 /// defined. 207 static std::map<char, int> BinopPrecedence; 208 209 /// GetTokPrecedence - Get the precedence of the pending binary operator token. 210 static int GetTokPrecedence() { 211 if (!isascii(CurTok)) 212 return -1; 213 214 // Make sure it's a declared binop. 215 int TokPrec = BinopPrecedence[CurTok]; 216 if (TokPrec <= 0) 217 return -1; 218 return TokPrec; 219 } 220 221 /// LogError* - These are little helper functions for error handling. 222 std::unique_ptr<ExprAST> LogError(const char *Str) { 223 fprintf(stderr, "Error: %s\n", Str); 224 return nullptr; 225 } 226 227 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) { 228 LogError(Str); 229 return nullptr; 230 } 231 232 static std::unique_ptr<ExprAST> ParseExpression(); 233 234 /// numberexpr ::= number 235 static std::unique_ptr<ExprAST> ParseNumberExpr() { 236 auto Result = std::make_unique<NumberExprAST>(NumVal); 237 getNextToken(); // consume the number 238 return std::move(Result); 239 } 240 241 /// parenexpr ::= '(' expression ')' 242 static std::unique_ptr<ExprAST> ParseParenExpr() { 243 getNextToken(); // eat (. 244 auto V = ParseExpression(); 245 if (!V) 246 return nullptr; 247 248 if (CurTok != ')') 249 return LogError("expected ')'"); 250 getNextToken(); // eat ). 251 return V; 252 } 253 254 /// identifierexpr 255 /// ::= identifier 256 /// ::= identifier '(' expression* ')' 257 static std::unique_ptr<ExprAST> ParseIdentifierExpr() { 258 std::string IdName = IdentifierStr; 259 260 getNextToken(); // eat identifier. 261 262 if (CurTok != '(') // Simple variable ref. 263 return std::make_unique<VariableExprAST>(IdName); 264 265 // Call. 266 getNextToken(); // eat ( 267 std::vector<std::unique_ptr<ExprAST>> Args; 268 if (CurTok != ')') { 269 while (true) { 270 if (auto Arg = ParseExpression()) 271 Args.push_back(std::move(Arg)); 272 else 273 return nullptr; 274 275 if (CurTok == ')') 276 break; 277 278 if (CurTok != ',') 279 return LogError("Expected ')' or ',' in argument list"); 280 getNextToken(); 281 } 282 } 283 284 // Eat the ')'. 285 getNextToken(); 286 287 return std::make_unique<CallExprAST>(IdName, std::move(Args)); 288 } 289 290 /// primary 291 /// ::= identifierexpr 292 /// ::= numberexpr 293 /// ::= parenexpr 294 static std::unique_ptr<ExprAST> ParsePrimary() { 295 switch (CurTok) { 296 default: 297 return LogError("unknown token when expecting an expression"); 298 case tok_identifier: 299 return ParseIdentifierExpr(); 300 case tok_number: 301 return ParseNumberExpr(); 302 case '(': 303 return ParseParenExpr(); 304 } 305 } 306 307 /// binoprhs 308 /// ::= ('+' primary)* 309 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, 310 std::unique_ptr<ExprAST> LHS) { 311 // If this is a binop, find its precedence. 312 while (true) { 313 int TokPrec = GetTokPrecedence(); 314 315 // If this is a binop that binds at least as tightly as the current binop, 316 // consume it, otherwise we are done. 317 if (TokPrec < ExprPrec) 318 return LHS; 319 320 // Okay, we know this is a binop. 321 int BinOp = CurTok; 322 getNextToken(); // eat binop 323 324 // Parse the primary expression after the binary operator. 325 auto RHS = ParsePrimary(); 326 if (!RHS) 327 return nullptr; 328 329 // If BinOp binds less tightly with RHS than the operator after RHS, let 330 // the pending operator take RHS as its LHS. 331 int NextPrec = GetTokPrecedence(); 332 if (TokPrec < NextPrec) { 333 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); 334 if (!RHS) 335 return nullptr; 336 } 337 338 // Merge LHS/RHS. 339 LHS = 340 std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS)); 341 } 342 } 343 344 /// expression 345 /// ::= primary binoprhs 346 /// 347 static std::unique_ptr<ExprAST> ParseExpression() { 348 auto LHS = ParsePrimary(); 349 if (!LHS) 350 return nullptr; 351 352 return ParseBinOpRHS(0, std::move(LHS)); 353 } 354 355 /// prototype 356 /// ::= id '(' id* ')' 357 static std::unique_ptr<PrototypeAST> ParsePrototype() { 358 if (CurTok != tok_identifier) 359 return LogErrorP("Expected function name in prototype"); 360 361 std::string FnName = IdentifierStr; 362 getNextToken(); 363 364 if (CurTok != '(') 365 return LogErrorP("Expected '(' in prototype"); 366 367 std::vector<std::string> ArgNames; 368 while (getNextToken() == tok_identifier) 369 ArgNames.push_back(IdentifierStr); 370 if (CurTok != ')') 371 return LogErrorP("Expected ')' in prototype"); 372 373 // success. 374 getNextToken(); // eat ')'. 375 376 return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames)); 377 } 378 379 /// definition ::= 'def' prototype expression 380 static std::unique_ptr<FunctionAST> ParseDefinition() { 381 getNextToken(); // eat def. 382 auto Proto = ParsePrototype(); 383 if (!Proto) 384 return nullptr; 385 386 if (auto E = ParseExpression()) 387 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 388 return nullptr; 389 } 390 391 /// toplevelexpr ::= expression 392 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() { 393 if (auto E = ParseExpression()) { 394 // Make an anonymous proto. 395 auto Proto = std::make_unique<PrototypeAST>("__anon_expr", 396 std::vector<std::string>()); 397 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 398 } 399 return nullptr; 400 } 401 402 /// external ::= 'extern' prototype 403 static std::unique_ptr<PrototypeAST> ParseExtern() { 404 getNextToken(); // eat extern. 405 return ParsePrototype(); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // Code Generation 410 //===----------------------------------------------------------------------===// 411 412 static LLVMContext TheContext; 413 static IRBuilder<> Builder(TheContext); 414 static std::unique_ptr<Module> TheModule; 415 static std::map<std::string, Value *> NamedValues; 416 static std::unique_ptr<legacy::FunctionPassManager> TheFPM; 417 static std::unique_ptr<KaleidoscopeJIT> TheJIT; 418 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos; 419 420 Value *LogErrorV(const char *Str) { 421 LogError(Str); 422 return nullptr; 423 } 424 425 Function *getFunction(std::string Name) { 426 // First, see if the function has already been added to the current module. 427 if (auto *F = TheModule->getFunction(Name)) 428 return F; 429 430 // If not, check whether we can codegen the declaration from some existing 431 // prototype. 432 auto FI = FunctionProtos.find(Name); 433 if (FI != FunctionProtos.end()) 434 return FI->second->codegen(); 435 436 // If no existing prototype exists, return null. 437 return nullptr; 438 } 439 440 Value *NumberExprAST::codegen() { 441 return ConstantFP::get(TheContext, APFloat(Val)); 442 } 443 444 Value *VariableExprAST::codegen() { 445 // Look this variable up in the function. 446 Value *V = NamedValues[Name]; 447 if (!V) 448 return LogErrorV("Unknown variable name"); 449 return V; 450 } 451 452 Value *BinaryExprAST::codegen() { 453 Value *L = LHS->codegen(); 454 Value *R = RHS->codegen(); 455 if (!L || !R) 456 return nullptr; 457 458 switch (Op) { 459 case '+': 460 return Builder.CreateFAdd(L, R, "addtmp"); 461 case '-': 462 return Builder.CreateFSub(L, R, "subtmp"); 463 case '*': 464 return Builder.CreateFMul(L, R, "multmp"); 465 case '<': 466 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 467 // Convert bool 0/1 to double 0.0 or 1.0 468 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp"); 469 default: 470 return LogErrorV("invalid binary operator"); 471 } 472 } 473 474 Value *CallExprAST::codegen() { 475 // Look up the name in the global module table. 476 Function *CalleeF = getFunction(Callee); 477 if (!CalleeF) 478 return LogErrorV("Unknown function referenced"); 479 480 // If argument mismatch error. 481 if (CalleeF->arg_size() != Args.size()) 482 return LogErrorV("Incorrect # arguments passed"); 483 484 std::vector<Value *> ArgsV; 485 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 486 ArgsV.push_back(Args[i]->codegen()); 487 if (!ArgsV.back()) 488 return nullptr; 489 } 490 491 return Builder.CreateCall(CalleeF, ArgsV, "calltmp"); 492 } 493 494 Function *PrototypeAST::codegen() { 495 // Make the function type: double(double,double) etc. 496 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext)); 497 FunctionType *FT = 498 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false); 499 500 Function *F = 501 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get()); 502 503 // Set names for all arguments. 504 unsigned Idx = 0; 505 for (auto &Arg : F->args()) 506 Arg.setName(Args[Idx++]); 507 508 return F; 509 } 510 511 Function *FunctionAST::codegen() { 512 // Transfer ownership of the prototype to the FunctionProtos map, but keep a 513 // reference to it for use below. 514 auto &P = *Proto; 515 FunctionProtos[Proto->getName()] = std::move(Proto); 516 Function *TheFunction = getFunction(P.getName()); 517 if (!TheFunction) 518 return nullptr; 519 520 // Create a new basic block to start insertion into. 521 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction); 522 Builder.SetInsertPoint(BB); 523 524 // Record the function arguments in the NamedValues map. 525 NamedValues.clear(); 526 for (auto &Arg : TheFunction->args()) 527 NamedValues[Arg.getName()] = &Arg; 528 529 if (Value *RetVal = Body->codegen()) { 530 // Finish off the function. 531 Builder.CreateRet(RetVal); 532 533 // Validate the generated code, checking for consistency. 534 verifyFunction(*TheFunction); 535 536 // Run the optimizer on the function. 537 TheFPM->run(*TheFunction); 538 539 return TheFunction; 540 } 541 542 // Error reading body, remove function. 543 TheFunction->eraseFromParent(); 544 return nullptr; 545 } 546 547 //===----------------------------------------------------------------------===// 548 // Top-Level parsing and JIT Driver 549 //===----------------------------------------------------------------------===// 550 551 static void InitializeModuleAndPassManager() { 552 // Open a new module. 553 TheModule = std::make_unique<Module>("my cool jit", TheContext); 554 TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout()); 555 556 // Create a new pass manager attached to it. 557 TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get()); 558 559 // Do simple "peephole" optimizations and bit-twiddling optzns. 560 TheFPM->add(createInstructionCombiningPass()); 561 // Reassociate expressions. 562 TheFPM->add(createReassociatePass()); 563 // Eliminate Common SubExpressions. 564 TheFPM->add(createGVNPass()); 565 // Simplify the control flow graph (deleting unreachable blocks, etc). 566 TheFPM->add(createCFGSimplificationPass()); 567 568 TheFPM->doInitialization(); 569 } 570 571 static void HandleDefinition() { 572 if (auto FnAST = ParseDefinition()) { 573 if (auto *FnIR = FnAST->codegen()) { 574 fprintf(stderr, "Read function definition:"); 575 FnIR->print(errs()); 576 fprintf(stderr, "\n"); 577 TheJIT->addModule(std::move(TheModule)); 578 InitializeModuleAndPassManager(); 579 } 580 } else { 581 // Skip token for error recovery. 582 getNextToken(); 583 } 584 } 585 586 static void HandleExtern() { 587 if (auto ProtoAST = ParseExtern()) { 588 if (auto *FnIR = ProtoAST->codegen()) { 589 fprintf(stderr, "Read extern: "); 590 FnIR->print(errs()); 591 fprintf(stderr, "\n"); 592 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); 593 } 594 } else { 595 // Skip token for error recovery. 596 getNextToken(); 597 } 598 } 599 600 static void HandleTopLevelExpression() { 601 // Evaluate a top-level expression into an anonymous function. 602 if (auto FnAST = ParseTopLevelExpr()) { 603 if (FnAST->codegen()) { 604 // JIT the module containing the anonymous expression, keeping a handle so 605 // we can free it later. 606 auto H = TheJIT->addModule(std::move(TheModule)); 607 InitializeModuleAndPassManager(); 608 609 // Search the JIT for the __anon_expr symbol. 610 auto ExprSymbol = TheJIT->findSymbol("__anon_expr"); 611 assert(ExprSymbol && "Function not found"); 612 613 // Get the symbol's address and cast it to the right type (takes no 614 // arguments, returns a double) so we can call it as a native function. 615 double (*FP)() = (double (*)())(intptr_t)cantFail(ExprSymbol.getAddress()); 616 fprintf(stderr, "Evaluated to %f\n", FP()); 617 618 // Delete the anonymous expression module from the JIT. 619 TheJIT->removeModule(H); 620 } 621 } else { 622 // Skip token for error recovery. 623 getNextToken(); 624 } 625 } 626 627 /// top ::= definition | external | expression | ';' 628 static void MainLoop() { 629 while (true) { 630 fprintf(stderr, "ready> "); 631 switch (CurTok) { 632 case tok_eof: 633 return; 634 case ';': // ignore top-level semicolons. 635 getNextToken(); 636 break; 637 case tok_def: 638 HandleDefinition(); 639 break; 640 case tok_extern: 641 HandleExtern(); 642 break; 643 default: 644 HandleTopLevelExpression(); 645 break; 646 } 647 } 648 } 649 650 //===----------------------------------------------------------------------===// 651 // "Library" functions that can be "extern'd" from user code. 652 //===----------------------------------------------------------------------===// 653 654 #ifdef _WIN32 655 #define DLLEXPORT __declspec(dllexport) 656 #else 657 #define DLLEXPORT 658 #endif 659 660 /// putchard - putchar that takes a double and returns 0. 661 extern "C" DLLEXPORT double putchard(double X) { 662 fputc((char)X, stderr); 663 return 0; 664 } 665 666 /// printd - printf that takes a double prints it as "%f\n", returning 0. 667 extern "C" DLLEXPORT double printd(double X) { 668 fprintf(stderr, "%f\n", X); 669 return 0; 670 } 671 672 //===----------------------------------------------------------------------===// 673 // Main driver code. 674 //===----------------------------------------------------------------------===// 675 676 int main() { 677 InitializeNativeTarget(); 678 InitializeNativeTargetAsmPrinter(); 679 InitializeNativeTargetAsmParser(); 680 681 // Install standard binary operators. 682 // 1 is lowest precedence. 683 BinopPrecedence['<'] = 10; 684 BinopPrecedence['+'] = 20; 685 BinopPrecedence['-'] = 20; 686 BinopPrecedence['*'] = 40; // highest. 687 688 // Prime the first token. 689 fprintf(stderr, "ready> "); 690 getNextToken(); 691 692 TheJIT = std::make_unique<KaleidoscopeJIT>(); 693 694 InitializeModuleAndPassManager(); 695 696 // Run the main "interpreter loop" now. 697 MainLoop(); 698 699 return 0; 700 } 701