1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/IRBuilder.h"
8 #include "llvm/IR/LLVMContext.h"
9 #include "llvm/IR/Module.h"
10 #include "llvm/IR/Type.h"
11 #include "llvm/IR/Verifier.h"
12 #include <algorithm>
13 #include <cctype>
14 #include <cstdio>
15 #include <cstdlib>
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 using namespace llvm;
22
23 //===----------------------------------------------------------------------===//
24 // Lexer
25 //===----------------------------------------------------------------------===//
26
// Token codes produced by the lexer. Characters the lexer does not recognise
// are returned as their own value in [0-255]; everything it does recognise is
// reported with one of the negative codes below.
enum Token {
  // end of input
  tok_eof = -1,

  // keywords
  tok_def = -2,
  tok_extern = -3,

  // primary expressions
  tok_identifier = -4,
  tok_number = -5
};
40
// Lexer side-channel values: gettok() stores the spelling or value of the
// token it just returned here.
static std::string IdentifierStr; // set when gettok() lexes an identifier/keyword
static double NumVal;             // set when gettok() returns tok_number
43
44 /// gettok - Return the next token from standard input.
gettok()45 static int gettok() {
46 static int LastChar = ' ';
47
48 // Skip any whitespace.
49 while (isspace(LastChar))
50 LastChar = getchar();
51
52 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
53 IdentifierStr = LastChar;
54 while (isalnum((LastChar = getchar())))
55 IdentifierStr += LastChar;
56
57 if (IdentifierStr == "def")
58 return tok_def;
59 if (IdentifierStr == "extern")
60 return tok_extern;
61 return tok_identifier;
62 }
63
64 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
65 std::string NumStr;
66 do {
67 NumStr += LastChar;
68 LastChar = getchar();
69 } while (isdigit(LastChar) || LastChar == '.');
70
71 NumVal = strtod(NumStr.c_str(), nullptr);
72 return tok_number;
73 }
74
75 if (LastChar == '#') {
76 // Comment until end of line.
77 do
78 LastChar = getchar();
79 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
80
81 if (LastChar != EOF)
82 return gettok();
83 }
84
85 // Check for end of file. Don't eat the EOF.
86 if (LastChar == EOF)
87 return tok_eof;
88
89 // Otherwise, just return the character as its ascii value.
90 int ThisChar = LastChar;
91 LastChar = getchar();
92 return ThisChar;
93 }
94
95 //===----------------------------------------------------------------------===//
96 // Abstract Syntax Tree (aka Parse Tree)
97 //===----------------------------------------------------------------------===//
98
99 namespace {
100
101 /// ExprAST - Base class for all expression nodes.
102 class ExprAST {
103 public:
104 virtual ~ExprAST() = default;
105
106 virtual Value *codegen() = 0;
107 };
108
109 /// NumberExprAST - Expression class for numeric literals like "1.0".
110 class NumberExprAST : public ExprAST {
111 double Val;
112
113 public:
NumberExprAST(double Val)114 NumberExprAST(double Val) : Val(Val) {}
115
116 Value *codegen() override;
117 };
118
119 /// VariableExprAST - Expression class for referencing a variable, like "a".
120 class VariableExprAST : public ExprAST {
121 std::string Name;
122
123 public:
VariableExprAST(const std::string & Name)124 VariableExprAST(const std::string &Name) : Name(Name) {}
125
126 Value *codegen() override;
127 };
128
129 /// BinaryExprAST - Expression class for a binary operator.
130 class BinaryExprAST : public ExprAST {
131 char Op;
132 std::unique_ptr<ExprAST> LHS, RHS;
133
134 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)135 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
136 std::unique_ptr<ExprAST> RHS)
137 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
138
139 Value *codegen() override;
140 };
141
142 /// CallExprAST - Expression class for function calls.
143 class CallExprAST : public ExprAST {
144 std::string Callee;
145 std::vector<std::unique_ptr<ExprAST>> Args;
146
147 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)148 CallExprAST(const std::string &Callee,
149 std::vector<std::unique_ptr<ExprAST>> Args)
150 : Callee(Callee), Args(std::move(Args)) {}
151
152 Value *codegen() override;
153 };
154
155 /// PrototypeAST - This class represents the "prototype" for a function,
156 /// which captures its name, and its argument names (thus implicitly the number
157 /// of arguments the function takes).
158 class PrototypeAST {
159 std::string Name;
160 std::vector<std::string> Args;
161
162 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args)163 PrototypeAST(const std::string &Name, std::vector<std::string> Args)
164 : Name(Name), Args(std::move(Args)) {}
165
166 Function *codegen();
getName() const167 const std::string &getName() const { return Name; }
168 };
169
170 /// FunctionAST - This class represents a function definition itself.
171 class FunctionAST {
172 std::unique_ptr<PrototypeAST> Proto;
173 std::unique_ptr<ExprAST> Body;
174
175 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)176 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
177 std::unique_ptr<ExprAST> Body)
178 : Proto(std::move(Proto)), Body(std::move(Body)) {}
179
180 Function *codegen();
181 };
182
183 } // end anonymous namespace
184
185 //===----------------------------------------------------------------------===//
186 // Parser
187 //===----------------------------------------------------------------------===//
188
189 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
190 /// token the parser is looking at. getNextToken reads another token from the
191 /// lexer and updates CurTok with its results.
192 static int CurTok;
getNextToken()193 static int getNextToken() { return CurTok = gettok(); }
194
195 /// BinopPrecedence - This holds the precedence for each binary operator that is
196 /// defined.
197 static std::map<char, int> BinopPrecedence;
198
199 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()200 static int GetTokPrecedence() {
201 if (!isascii(CurTok))
202 return -1;
203
204 // Make sure it's a declared binop.
205 int TokPrec = BinopPrecedence[CurTok];
206 if (TokPrec <= 0)
207 return -1;
208 return TokPrec;
209 }
210
211 /// LogError* - These are little helper functions for error handling.
LogError(const char * Str)212 std::unique_ptr<ExprAST> LogError(const char *Str) {
213 fprintf(stderr, "Error: %s\n", Str);
214 return nullptr;
215 }
216
LogErrorP(const char * Str)217 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
218 LogError(Str);
219 return nullptr;
220 }
221
222 static std::unique_ptr<ExprAST> ParseExpression();
223
224 /// numberexpr ::= number
ParseNumberExpr()225 static std::unique_ptr<ExprAST> ParseNumberExpr() {
226 auto Result = llvm::make_unique<NumberExprAST>(NumVal);
227 getNextToken(); // consume the number
228 return std::move(Result);
229 }
230
231 /// parenexpr ::= '(' expression ')'
ParseParenExpr()232 static std::unique_ptr<ExprAST> ParseParenExpr() {
233 getNextToken(); // eat (.
234 auto V = ParseExpression();
235 if (!V)
236 return nullptr;
237
238 if (CurTok != ')')
239 return LogError("expected ')'");
240 getNextToken(); // eat ).
241 return V;
242 }
243
244 /// identifierexpr
245 /// ::= identifier
246 /// ::= identifier '(' expression* ')'
ParseIdentifierExpr()247 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
248 std::string IdName = IdentifierStr;
249
250 getNextToken(); // eat identifier.
251
252 if (CurTok != '(') // Simple variable ref.
253 return llvm::make_unique<VariableExprAST>(IdName);
254
255 // Call.
256 getNextToken(); // eat (
257 std::vector<std::unique_ptr<ExprAST>> Args;
258 if (CurTok != ')') {
259 while (true) {
260 if (auto Arg = ParseExpression())
261 Args.push_back(std::move(Arg));
262 else
263 return nullptr;
264
265 if (CurTok == ')')
266 break;
267
268 if (CurTok != ',')
269 return LogError("Expected ')' or ',' in argument list");
270 getNextToken();
271 }
272 }
273
274 // Eat the ')'.
275 getNextToken();
276
277 return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
278 }
279
280 /// primary
281 /// ::= identifierexpr
282 /// ::= numberexpr
283 /// ::= parenexpr
ParsePrimary()284 static std::unique_ptr<ExprAST> ParsePrimary() {
285 switch (CurTok) {
286 default:
287 return LogError("unknown token when expecting an expression");
288 case tok_identifier:
289 return ParseIdentifierExpr();
290 case tok_number:
291 return ParseNumberExpr();
292 case '(':
293 return ParseParenExpr();
294 }
295 }
296
297 /// binoprhs
298 /// ::= ('+' primary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)299 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
300 std::unique_ptr<ExprAST> LHS) {
301 // If this is a binop, find its precedence.
302 while (true) {
303 int TokPrec = GetTokPrecedence();
304
305 // If this is a binop that binds at least as tightly as the current binop,
306 // consume it, otherwise we are done.
307 if (TokPrec < ExprPrec)
308 return LHS;
309
310 // Okay, we know this is a binop.
311 int BinOp = CurTok;
312 getNextToken(); // eat binop
313
314 // Parse the primary expression after the binary operator.
315 auto RHS = ParsePrimary();
316 if (!RHS)
317 return nullptr;
318
319 // If BinOp binds less tightly with RHS than the operator after RHS, let
320 // the pending operator take RHS as its LHS.
321 int NextPrec = GetTokPrecedence();
322 if (TokPrec < NextPrec) {
323 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
324 if (!RHS)
325 return nullptr;
326 }
327
328 // Merge LHS/RHS.
329 LHS =
330 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
331 }
332 }
333
334 /// expression
335 /// ::= primary binoprhs
336 ///
ParseExpression()337 static std::unique_ptr<ExprAST> ParseExpression() {
338 auto LHS = ParsePrimary();
339 if (!LHS)
340 return nullptr;
341
342 return ParseBinOpRHS(0, std::move(LHS));
343 }
344
345 /// prototype
346 /// ::= id '(' id* ')'
ParsePrototype()347 static std::unique_ptr<PrototypeAST> ParsePrototype() {
348 if (CurTok != tok_identifier)
349 return LogErrorP("Expected function name in prototype");
350
351 std::string FnName = IdentifierStr;
352 getNextToken();
353
354 if (CurTok != '(')
355 return LogErrorP("Expected '(' in prototype");
356
357 std::vector<std::string> ArgNames;
358 while (getNextToken() == tok_identifier)
359 ArgNames.push_back(IdentifierStr);
360 if (CurTok != ')')
361 return LogErrorP("Expected ')' in prototype");
362
363 // success.
364 getNextToken(); // eat ')'.
365
366 return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
367 }
368
369 /// definition ::= 'def' prototype expression
ParseDefinition()370 static std::unique_ptr<FunctionAST> ParseDefinition() {
371 getNextToken(); // eat def.
372 auto Proto = ParsePrototype();
373 if (!Proto)
374 return nullptr;
375
376 if (auto E = ParseExpression())
377 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
378 return nullptr;
379 }
380
381 /// toplevelexpr ::= expression
ParseTopLevelExpr()382 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
383 if (auto E = ParseExpression()) {
384 // Make an anonymous proto.
385 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
386 std::vector<std::string>());
387 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
388 }
389 return nullptr;
390 }
391
392 /// external ::= 'extern' prototype
ParseExtern()393 static std::unique_ptr<PrototypeAST> ParseExtern() {
394 getNextToken(); // eat extern.
395 return ParsePrototype();
396 }
397
398 //===----------------------------------------------------------------------===//
399 // Code Generation
400 //===----------------------------------------------------------------------===//
401
402 static LLVMContext TheContext;
403 static IRBuilder<> Builder(TheContext);
404 static std::unique_ptr<Module> TheModule;
405 static std::map<std::string, Value *> NamedValues;
406
LogErrorV(const char * Str)407 Value *LogErrorV(const char *Str) {
408 LogError(Str);
409 return nullptr;
410 }
411
codegen()412 Value *NumberExprAST::codegen() {
413 return ConstantFP::get(TheContext, APFloat(Val));
414 }
415
codegen()416 Value *VariableExprAST::codegen() {
417 // Look this variable up in the function.
418 Value *V = NamedValues[Name];
419 if (!V)
420 return LogErrorV("Unknown variable name");
421 return V;
422 }
423
codegen()424 Value *BinaryExprAST::codegen() {
425 Value *L = LHS->codegen();
426 Value *R = RHS->codegen();
427 if (!L || !R)
428 return nullptr;
429
430 switch (Op) {
431 case '+':
432 return Builder.CreateFAdd(L, R, "addtmp");
433 case '-':
434 return Builder.CreateFSub(L, R, "subtmp");
435 case '*':
436 return Builder.CreateFMul(L, R, "multmp");
437 case '<':
438 L = Builder.CreateFCmpULT(L, R, "cmptmp");
439 // Convert bool 0/1 to double 0.0 or 1.0
440 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
441 default:
442 return LogErrorV("invalid binary operator");
443 }
444 }
445
codegen()446 Value *CallExprAST::codegen() {
447 // Look up the name in the global module table.
448 Function *CalleeF = TheModule->getFunction(Callee);
449 if (!CalleeF)
450 return LogErrorV("Unknown function referenced");
451
452 // If argument mismatch error.
453 if (CalleeF->arg_size() != Args.size())
454 return LogErrorV("Incorrect # arguments passed");
455
456 std::vector<Value *> ArgsV;
457 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
458 ArgsV.push_back(Args[i]->codegen());
459 if (!ArgsV.back())
460 return nullptr;
461 }
462
463 return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
464 }
465
codegen()466 Function *PrototypeAST::codegen() {
467 // Make the function type: double(double,double) etc.
468 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
469 FunctionType *FT =
470 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
471
472 Function *F =
473 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
474
475 // Set names for all arguments.
476 unsigned Idx = 0;
477 for (auto &Arg : F->args())
478 Arg.setName(Args[Idx++]);
479
480 return F;
481 }
482
codegen()483 Function *FunctionAST::codegen() {
484 // First, check for an existing function from a previous 'extern' declaration.
485 Function *TheFunction = TheModule->getFunction(Proto->getName());
486
487 if (!TheFunction)
488 TheFunction = Proto->codegen();
489
490 if (!TheFunction)
491 return nullptr;
492
493 // Create a new basic block to start insertion into.
494 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
495 Builder.SetInsertPoint(BB);
496
497 // Record the function arguments in the NamedValues map.
498 NamedValues.clear();
499 for (auto &Arg : TheFunction->args())
500 NamedValues[Arg.getName()] = &Arg;
501
502 if (Value *RetVal = Body->codegen()) {
503 // Finish off the function.
504 Builder.CreateRet(RetVal);
505
506 // Validate the generated code, checking for consistency.
507 verifyFunction(*TheFunction);
508
509 return TheFunction;
510 }
511
512 // Error reading body, remove function.
513 TheFunction->eraseFromParent();
514 return nullptr;
515 }
516
517 //===----------------------------------------------------------------------===//
518 // Top-Level parsing and JIT Driver
519 //===----------------------------------------------------------------------===//
520
HandleDefinition()521 static void HandleDefinition() {
522 if (auto FnAST = ParseDefinition()) {
523 if (auto *FnIR = FnAST->codegen()) {
524 fprintf(stderr, "Read function definition:");
525 FnIR->print(errs());
526 fprintf(stderr, "\n");
527 }
528 } else {
529 // Skip token for error recovery.
530 getNextToken();
531 }
532 }
533
HandleExtern()534 static void HandleExtern() {
535 if (auto ProtoAST = ParseExtern()) {
536 if (auto *FnIR = ProtoAST->codegen()) {
537 fprintf(stderr, "Read extern: ");
538 FnIR->print(errs());
539 fprintf(stderr, "\n");
540 }
541 } else {
542 // Skip token for error recovery.
543 getNextToken();
544 }
545 }
546
HandleTopLevelExpression()547 static void HandleTopLevelExpression() {
548 // Evaluate a top-level expression into an anonymous function.
549 if (auto FnAST = ParseTopLevelExpr()) {
550 if (auto *FnIR = FnAST->codegen()) {
551 fprintf(stderr, "Read top-level expression:");
552 FnIR->print(errs());
553 fprintf(stderr, "\n");
554 }
555 } else {
556 // Skip token for error recovery.
557 getNextToken();
558 }
559 }
560
561 /// top ::= definition | external | expression | ';'
MainLoop()562 static void MainLoop() {
563 while (true) {
564 fprintf(stderr, "ready> ");
565 switch (CurTok) {
566 case tok_eof:
567 return;
568 case ';': // ignore top-level semicolons.
569 getNextToken();
570 break;
571 case tok_def:
572 HandleDefinition();
573 break;
574 case tok_extern:
575 HandleExtern();
576 break;
577 default:
578 HandleTopLevelExpression();
579 break;
580 }
581 }
582 }
583
584 //===----------------------------------------------------------------------===//
585 // Main driver code.
586 //===----------------------------------------------------------------------===//
587
main()588 int main() {
589 // Install standard binary operators.
590 // 1 is lowest precedence.
591 BinopPrecedence['<'] = 10;
592 BinopPrecedence['+'] = 20;
593 BinopPrecedence['-'] = 20;
594 BinopPrecedence['*'] = 40; // highest.
595
596 // Prime the first token.
597 fprintf(stderr, "ready> ");
598 getNextToken();
599
600 // Make the module, which holds all the code.
601 TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
602
603 // Run the main "interpreter loop" now.
604 MainLoop();
605
606 // Print out all of the generated code.
607 TheModule->print(errs(), nullptr);
608
609 return 0;
610 }
611