1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/IRBuilder.h"
8 #include "llvm/IR/LLVMContext.h"
9 #include "llvm/IR/Module.h"
10 #include "llvm/IR/Type.h"
11 #include "llvm/IR/Verifier.h"
12 #include <algorithm>
13 #include <cctype>
14 #include <cstdio>
15 #include <cstdlib>
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <vector>
20 
21 using namespace llvm;
22 
23 //===----------------------------------------------------------------------===//
24 // Lexer
25 //===----------------------------------------------------------------------===//
26 
// Token codes produced by the lexer. Recognised constructs use the negative
// values below; any other character is returned as its own value [0-255].
enum Token {
  // end of input
  tok_eof = -1,

  // keywords
  tok_def = -2,
  tok_extern = -3,

  // primary expressions
  tok_identifier = -4,
  tok_number = -5
};
40 
// Payload of the most recently lexed token.
static std::string IdentifierStr; // spelling, valid after tok_identifier
static double NumVal;             // value, valid after tok_number
43 
44 /// gettok - Return the next token from standard input.
gettok()45 static int gettok() {
46   static int LastChar = ' ';
47 
48   // Skip any whitespace.
49   while (isspace(LastChar))
50     LastChar = getchar();
51 
52   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
53     IdentifierStr = LastChar;
54     while (isalnum((LastChar = getchar())))
55       IdentifierStr += LastChar;
56 
57     if (IdentifierStr == "def")
58       return tok_def;
59     if (IdentifierStr == "extern")
60       return tok_extern;
61     return tok_identifier;
62   }
63 
64   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
65     std::string NumStr;
66     do {
67       NumStr += LastChar;
68       LastChar = getchar();
69     } while (isdigit(LastChar) || LastChar == '.');
70 
71     NumVal = strtod(NumStr.c_str(), nullptr);
72     return tok_number;
73   }
74 
75   if (LastChar == '#') {
76     // Comment until end of line.
77     do
78       LastChar = getchar();
79     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
80 
81     if (LastChar != EOF)
82       return gettok();
83   }
84 
85   // Check for end of file.  Don't eat the EOF.
86   if (LastChar == EOF)
87     return tok_eof;
88 
89   // Otherwise, just return the character as its ascii value.
90   int ThisChar = LastChar;
91   LastChar = getchar();
92   return ThisChar;
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Abstract Syntax Tree (aka Parse Tree)
97 //===----------------------------------------------------------------------===//
98 
99 namespace {
100 
101 /// ExprAST - Base class for all expression nodes.
102 class ExprAST {
103 public:
104   virtual ~ExprAST() = default;
105 
106   virtual Value *codegen() = 0;
107 };
108 
109 /// NumberExprAST - Expression class for numeric literals like "1.0".
110 class NumberExprAST : public ExprAST {
111   double Val;
112 
113 public:
NumberExprAST(double Val)114   NumberExprAST(double Val) : Val(Val) {}
115 
116   Value *codegen() override;
117 };
118 
119 /// VariableExprAST - Expression class for referencing a variable, like "a".
120 class VariableExprAST : public ExprAST {
121   std::string Name;
122 
123 public:
VariableExprAST(const std::string & Name)124   VariableExprAST(const std::string &Name) : Name(Name) {}
125 
126   Value *codegen() override;
127 };
128 
129 /// BinaryExprAST - Expression class for a binary operator.
130 class BinaryExprAST : public ExprAST {
131   char Op;
132   std::unique_ptr<ExprAST> LHS, RHS;
133 
134 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)135   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
136                 std::unique_ptr<ExprAST> RHS)
137       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
138 
139   Value *codegen() override;
140 };
141 
142 /// CallExprAST - Expression class for function calls.
143 class CallExprAST : public ExprAST {
144   std::string Callee;
145   std::vector<std::unique_ptr<ExprAST>> Args;
146 
147 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)148   CallExprAST(const std::string &Callee,
149               std::vector<std::unique_ptr<ExprAST>> Args)
150       : Callee(Callee), Args(std::move(Args)) {}
151 
152   Value *codegen() override;
153 };
154 
155 /// PrototypeAST - This class represents the "prototype" for a function,
156 /// which captures its name, and its argument names (thus implicitly the number
157 /// of arguments the function takes).
158 class PrototypeAST {
159   std::string Name;
160   std::vector<std::string> Args;
161 
162 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args)163   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
164       : Name(Name), Args(std::move(Args)) {}
165 
166   Function *codegen();
getName() const167   const std::string &getName() const { return Name; }
168 };
169 
170 /// FunctionAST - This class represents a function definition itself.
171 class FunctionAST {
172   std::unique_ptr<PrototypeAST> Proto;
173   std::unique_ptr<ExprAST> Body;
174 
175 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)176   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
177               std::unique_ptr<ExprAST> Body)
178       : Proto(std::move(Proto)), Body(std::move(Body)) {}
179 
180   Function *codegen();
181 };
182 
183 } // end anonymous namespace
184 
185 //===----------------------------------------------------------------------===//
186 // Parser
187 //===----------------------------------------------------------------------===//
188 
189 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
190 /// token the parser is looking at.  getNextToken reads another token from the
191 /// lexer and updates CurTok with its results.
192 static int CurTok;
getNextToken()193 static int getNextToken() { return CurTok = gettok(); }
194 
/// BinopPrecedence - Maps each defined binary operator character to its
/// precedence.
static std::map<char, int> BinopPrecedence;
198 
199 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()200 static int GetTokPrecedence() {
201   if (!isascii(CurTok))
202     return -1;
203 
204   // Make sure it's a declared binop.
205   int TokPrec = BinopPrecedence[CurTok];
206   if (TokPrec <= 0)
207     return -1;
208   return TokPrec;
209 }
210 
211 /// LogError* - These are little helper functions for error handling.
LogError(const char * Str)212 std::unique_ptr<ExprAST> LogError(const char *Str) {
213   fprintf(stderr, "Error: %s\n", Str);
214   return nullptr;
215 }
216 
LogErrorP(const char * Str)217 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
218   LogError(Str);
219   return nullptr;
220 }
221 
222 static std::unique_ptr<ExprAST> ParseExpression();
223 
224 /// numberexpr ::= number
ParseNumberExpr()225 static std::unique_ptr<ExprAST> ParseNumberExpr() {
226   auto Result = std::make_unique<NumberExprAST>(NumVal);
227   getNextToken(); // consume the number
228   return std::move(Result);
229 }
230 
231 /// parenexpr ::= '(' expression ')'
ParseParenExpr()232 static std::unique_ptr<ExprAST> ParseParenExpr() {
233   getNextToken(); // eat (.
234   auto V = ParseExpression();
235   if (!V)
236     return nullptr;
237 
238   if (CurTok != ')')
239     return LogError("expected ')'");
240   getNextToken(); // eat ).
241   return V;
242 }
243 
244 /// identifierexpr
245 ///   ::= identifier
246 ///   ::= identifier '(' expression* ')'
ParseIdentifierExpr()247 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
248   std::string IdName = IdentifierStr;
249 
250   getNextToken(); // eat identifier.
251 
252   if (CurTok != '(') // Simple variable ref.
253     return std::make_unique<VariableExprAST>(IdName);
254 
255   // Call.
256   getNextToken(); // eat (
257   std::vector<std::unique_ptr<ExprAST>> Args;
258   if (CurTok != ')') {
259     while (true) {
260       if (auto Arg = ParseExpression())
261         Args.push_back(std::move(Arg));
262       else
263         return nullptr;
264 
265       if (CurTok == ')')
266         break;
267 
268       if (CurTok != ',')
269         return LogError("Expected ')' or ',' in argument list");
270       getNextToken();
271     }
272   }
273 
274   // Eat the ')'.
275   getNextToken();
276 
277   return std::make_unique<CallExprAST>(IdName, std::move(Args));
278 }
279 
280 /// primary
281 ///   ::= identifierexpr
282 ///   ::= numberexpr
283 ///   ::= parenexpr
ParsePrimary()284 static std::unique_ptr<ExprAST> ParsePrimary() {
285   switch (CurTok) {
286   default:
287     return LogError("unknown token when expecting an expression");
288   case tok_identifier:
289     return ParseIdentifierExpr();
290   case tok_number:
291     return ParseNumberExpr();
292   case '(':
293     return ParseParenExpr();
294   }
295 }
296 
297 /// binoprhs
298 ///   ::= ('+' primary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)299 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
300                                               std::unique_ptr<ExprAST> LHS) {
301   // If this is a binop, find its precedence.
302   while (true) {
303     int TokPrec = GetTokPrecedence();
304 
305     // If this is a binop that binds at least as tightly as the current binop,
306     // consume it, otherwise we are done.
307     if (TokPrec < ExprPrec)
308       return LHS;
309 
310     // Okay, we know this is a binop.
311     int BinOp = CurTok;
312     getNextToken(); // eat binop
313 
314     // Parse the primary expression after the binary operator.
315     auto RHS = ParsePrimary();
316     if (!RHS)
317       return nullptr;
318 
319     // If BinOp binds less tightly with RHS than the operator after RHS, let
320     // the pending operator take RHS as its LHS.
321     int NextPrec = GetTokPrecedence();
322     if (TokPrec < NextPrec) {
323       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
324       if (!RHS)
325         return nullptr;
326     }
327 
328     // Merge LHS/RHS.
329     LHS =
330         std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
331   }
332 }
333 
334 /// expression
335 ///   ::= primary binoprhs
336 ///
ParseExpression()337 static std::unique_ptr<ExprAST> ParseExpression() {
338   auto LHS = ParsePrimary();
339   if (!LHS)
340     return nullptr;
341 
342   return ParseBinOpRHS(0, std::move(LHS));
343 }
344 
345 /// prototype
346 ///   ::= id '(' id* ')'
ParsePrototype()347 static std::unique_ptr<PrototypeAST> ParsePrototype() {
348   if (CurTok != tok_identifier)
349     return LogErrorP("Expected function name in prototype");
350 
351   std::string FnName = IdentifierStr;
352   getNextToken();
353 
354   if (CurTok != '(')
355     return LogErrorP("Expected '(' in prototype");
356 
357   std::vector<std::string> ArgNames;
358   while (getNextToken() == tok_identifier)
359     ArgNames.push_back(IdentifierStr);
360   if (CurTok != ')')
361     return LogErrorP("Expected ')' in prototype");
362 
363   // success.
364   getNextToken(); // eat ')'.
365 
366   return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
367 }
368 
369 /// definition ::= 'def' prototype expression
ParseDefinition()370 static std::unique_ptr<FunctionAST> ParseDefinition() {
371   getNextToken(); // eat def.
372   auto Proto = ParsePrototype();
373   if (!Proto)
374     return nullptr;
375 
376   if (auto E = ParseExpression())
377     return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
378   return nullptr;
379 }
380 
381 /// toplevelexpr ::= expression
ParseTopLevelExpr()382 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
383   if (auto E = ParseExpression()) {
384     // Make an anonymous proto.
385     auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
386                                                  std::vector<std::string>());
387     return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
388   }
389   return nullptr;
390 }
391 
392 /// external ::= 'extern' prototype
ParseExtern()393 static std::unique_ptr<PrototypeAST> ParseExtern() {
394   getNextToken(); // eat extern.
395   return ParsePrototype();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // Code Generation
400 //===----------------------------------------------------------------------===//
401 
402 static std::unique_ptr<LLVMContext> TheContext;
403 static std::unique_ptr<Module> TheModule;
404 static std::unique_ptr<IRBuilder<>> Builder;
405 static std::map<std::string, Value *> NamedValues;
406 
LogErrorV(const char * Str)407 Value *LogErrorV(const char *Str) {
408   LogError(Str);
409   return nullptr;
410 }
411 
codegen()412 Value *NumberExprAST::codegen() {
413   return ConstantFP::get(*TheContext, APFloat(Val));
414 }
415 
codegen()416 Value *VariableExprAST::codegen() {
417   // Look this variable up in the function.
418   Value *V = NamedValues[Name];
419   if (!V)
420     return LogErrorV("Unknown variable name");
421   return V;
422 }
423 
codegen()424 Value *BinaryExprAST::codegen() {
425   Value *L = LHS->codegen();
426   Value *R = RHS->codegen();
427   if (!L || !R)
428     return nullptr;
429 
430   switch (Op) {
431   case '+':
432     return Builder->CreateFAdd(L, R, "addtmp");
433   case '-':
434     return Builder->CreateFSub(L, R, "subtmp");
435   case '*':
436     return Builder->CreateFMul(L, R, "multmp");
437   case '<':
438     L = Builder->CreateFCmpULT(L, R, "cmptmp");
439     // Convert bool 0/1 to double 0.0 or 1.0
440     return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
441   default:
442     return LogErrorV("invalid binary operator");
443   }
444 }
445 
codegen()446 Value *CallExprAST::codegen() {
447   // Look up the name in the global module table.
448   Function *CalleeF = TheModule->getFunction(Callee);
449   if (!CalleeF)
450     return LogErrorV("Unknown function referenced");
451 
452   // If argument mismatch error.
453   if (CalleeF->arg_size() != Args.size())
454     return LogErrorV("Incorrect # arguments passed");
455 
456   std::vector<Value *> ArgsV;
457   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
458     ArgsV.push_back(Args[i]->codegen());
459     if (!ArgsV.back())
460       return nullptr;
461   }
462 
463   return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
464 }
465 
codegen()466 Function *PrototypeAST::codegen() {
467   // Make the function type:  double(double,double) etc.
468   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
469   FunctionType *FT =
470       FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
471 
472   Function *F =
473       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
474 
475   // Set names for all arguments.
476   unsigned Idx = 0;
477   for (auto &Arg : F->args())
478     Arg.setName(Args[Idx++]);
479 
480   return F;
481 }
482 
codegen()483 Function *FunctionAST::codegen() {
484   // First, check for an existing function from a previous 'extern' declaration.
485   Function *TheFunction = TheModule->getFunction(Proto->getName());
486 
487   if (!TheFunction)
488     TheFunction = Proto->codegen();
489 
490   if (!TheFunction)
491     return nullptr;
492 
493   // Create a new basic block to start insertion into.
494   BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
495   Builder->SetInsertPoint(BB);
496 
497   // Record the function arguments in the NamedValues map.
498   NamedValues.clear();
499   for (auto &Arg : TheFunction->args())
500     NamedValues[std::string(Arg.getName())] = &Arg;
501 
502   if (Value *RetVal = Body->codegen()) {
503     // Finish off the function.
504     Builder->CreateRet(RetVal);
505 
506     // Validate the generated code, checking for consistency.
507     verifyFunction(*TheFunction);
508 
509     return TheFunction;
510   }
511 
512   // Error reading body, remove function.
513   TheFunction->eraseFromParent();
514   return nullptr;
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // Top-Level parsing and JIT Driver
519 //===----------------------------------------------------------------------===//
520 
InitializeModule()521 static void InitializeModule() {
522   // Open a new context and module.
523   TheContext = std::make_unique<LLVMContext>();
524   TheModule = std::make_unique<Module>("my cool jit", *TheContext);
525 
526   // Create a new builder for the module.
527   Builder = std::make_unique<IRBuilder<>>(*TheContext);
528 }
529 
HandleDefinition()530 static void HandleDefinition() {
531   if (auto FnAST = ParseDefinition()) {
532     if (auto *FnIR = FnAST->codegen()) {
533       fprintf(stderr, "Read function definition:");
534       FnIR->print(errs());
535       fprintf(stderr, "\n");
536     }
537   } else {
538     // Skip token for error recovery.
539     getNextToken();
540   }
541 }
542 
HandleExtern()543 static void HandleExtern() {
544   if (auto ProtoAST = ParseExtern()) {
545     if (auto *FnIR = ProtoAST->codegen()) {
546       fprintf(stderr, "Read extern: ");
547       FnIR->print(errs());
548       fprintf(stderr, "\n");
549     }
550   } else {
551     // Skip token for error recovery.
552     getNextToken();
553   }
554 }
555 
HandleTopLevelExpression()556 static void HandleTopLevelExpression() {
557   // Evaluate a top-level expression into an anonymous function.
558   if (auto FnAST = ParseTopLevelExpr()) {
559     if (auto *FnIR = FnAST->codegen()) {
560       fprintf(stderr, "Read top-level expression:");
561       FnIR->print(errs());
562       fprintf(stderr, "\n");
563 
564       // Remove the anonymous expression.
565       FnIR->eraseFromParent();
566     }
567   } else {
568     // Skip token for error recovery.
569     getNextToken();
570   }
571 }
572 
573 /// top ::= definition | external | expression | ';'
MainLoop()574 static void MainLoop() {
575   while (true) {
576     fprintf(stderr, "ready> ");
577     switch (CurTok) {
578     case tok_eof:
579       return;
580     case ';': // ignore top-level semicolons.
581       getNextToken();
582       break;
583     case tok_def:
584       HandleDefinition();
585       break;
586     case tok_extern:
587       HandleExtern();
588       break;
589     default:
590       HandleTopLevelExpression();
591       break;
592     }
593   }
594 }
595 
596 //===----------------------------------------------------------------------===//
597 // Main driver code.
598 //===----------------------------------------------------------------------===//
599 
main()600 int main() {
601   // Install standard binary operators.
602   // 1 is lowest precedence.
603   BinopPrecedence['<'] = 10;
604   BinopPrecedence['+'] = 20;
605   BinopPrecedence['-'] = 20;
606   BinopPrecedence['*'] = 40; // highest.
607 
608   // Prime the first token.
609   fprintf(stderr, "ready> ");
610   getNextToken();
611 
612   // Make the module, which holds all the code.
613   InitializeModule();
614 
615   // Run the main "interpreter loop" now.
616   MainLoop();
617 
618   // Print out all of the generated code.
619   TheModule->print(errs(), nullptr);
620 
621   return 0;
622 }
623