1 /*
2   Copyright (c) 2011-2021, Intel Corporation
3   All rights reserved.
4 
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions are
7   met:
8 
9     * Redistributions of source code must retain the above copyright
10       notice, this list of conditions and the following disclaimer.
11 
12     * Redistributions in binary form must reproduce the above copyright
13       notice, this list of conditions and the following disclaimer in the
14       documentation and/or other materials provided with the distribution.
15 
16     * Neither the name of Intel Corporation nor the names of its
17       contributors may be used to endorse or promote products derived from
18       this software without specific prior written permission.
19 
20 
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23    TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33 
34 /** @file ast.cpp
35 
36     @brief General functionality related to abstract syntax trees and
37     traversal of them.
38  */
39 
40 #include "ast.h"
41 #include "expr.h"
42 #include "func.h"
43 #include "stmt.h"
44 #include "sym.h"
45 #include "util.h"
46 
47 #include <llvm/Support/TimeProfiler.h>
48 
49 using namespace ispc;
50 
51 ///////////////////////////////////////////////////////////////////////////
52 // ASTNode
53 
~ASTNode()54 ASTNode::~ASTNode() {}
55 
56 ///////////////////////////////////////////////////////////////////////////
57 // AST
58 
AddFunction(Symbol * sym,Stmt * code)59 void AST::AddFunction(Symbol *sym, Stmt *code) {
60     if (sym == NULL)
61         return;
62     functions.push_back(new Function(sym, code));
63 }
64 
GenerateIR()65 void AST::GenerateIR() {
66     llvm::TimeTraceScope TimeScope("GenerateIR");
67     for (unsigned int i = 0; i < functions.size(); ++i)
68         functions[i]->GenerateIR();
69 }
70 
71 ///////////////////////////////////////////////////////////////////////////
72 
WalkAST(ASTNode * node,ASTPreCallBackFunc preFunc,ASTPostCallBackFunc postFunc,void * data)73 ASTNode *ispc::WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, void *data) {
74     if (node == NULL)
75         return node;
76 
77     // Call the callback function
78     if (preFunc != NULL) {
79         if (preFunc(node, data) == false)
80             // The function asked us to not continue recursively, so stop.
81             return node;
82     }
83 
84     ////////////////////////////////////////////////////////////////////////////
85     // Handle Statements
86     if (llvm::dyn_cast<Stmt>(node) != NULL) {
87         ExprStmt *es;
88         DeclStmt *ds;
89         IfStmt *is;
90         DoStmt *dos;
91         ForStmt *fs;
92         ForeachStmt *fes;
93         ForeachActiveStmt *fas;
94         ForeachUniqueStmt *fus;
95         CaseStmt *cs;
96         DefaultStmt *defs;
97         SwitchStmt *ss;
98         ReturnStmt *rs;
99         LabeledStmt *ls;
100         StmtList *sl;
101         PrintStmt *ps;
102         AssertStmt *as;
103         DeleteStmt *dels;
104         UnmaskedStmt *ums;
105 
106         if ((es = llvm::dyn_cast<ExprStmt>(node)) != NULL)
107             es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data);
108         else if ((ds = llvm::dyn_cast<DeclStmt>(node)) != NULL) {
109             for (unsigned int i = 0; i < ds->vars.size(); ++i)
110                 ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, postFunc, data);
111         } else if ((is = llvm::dyn_cast<IfStmt>(node)) != NULL) {
112             is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data);
113             is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, postFunc, data);
114             is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, postFunc, data);
115         } else if ((dos = llvm::dyn_cast<DoStmt>(node)) != NULL) {
116             dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, postFunc, data);
117             dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, postFunc, data);
118         } else if ((fs = llvm::dyn_cast<ForStmt>(node)) != NULL) {
119             fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data);
120             fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data);
121             fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data);
122             fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data);
123         } else if ((fes = llvm::dyn_cast<ForeachStmt>(node)) != NULL) {
124             for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
125                 fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, postFunc, data);
126             for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
127                 fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, postFunc, data);
128             fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data);
129         } else if ((fas = llvm::dyn_cast<ForeachActiveStmt>(node)) != NULL) {
130             fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data);
131         } else if ((fus = llvm::dyn_cast<ForeachUniqueStmt>(node)) != NULL) {
132             fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data);
133             fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data);
134         } else if ((cs = llvm::dyn_cast<CaseStmt>(node)) != NULL)
135             cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data);
136         else if ((defs = llvm::dyn_cast<DefaultStmt>(node)) != NULL)
137             defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data);
138         else if ((ss = llvm::dyn_cast<SwitchStmt>(node)) != NULL) {
139             ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data);
140             ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data);
141         } else if (llvm::dyn_cast<BreakStmt>(node) != NULL || llvm::dyn_cast<ContinueStmt>(node) != NULL ||
142                    llvm::dyn_cast<GotoStmt>(node) != NULL) {
143             // nothing
144         } else if ((ls = llvm::dyn_cast<LabeledStmt>(node)) != NULL)
145             ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data);
146         else if ((rs = llvm::dyn_cast<ReturnStmt>(node)) != NULL)
147             rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data);
148         else if ((sl = llvm::dyn_cast<StmtList>(node)) != NULL) {
149             std::vector<Stmt *> &sls = sl->stmts;
150             for (unsigned int i = 0; i < sls.size(); ++i)
151                 sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data);
152         } else if ((ps = llvm::dyn_cast<PrintStmt>(node)) != NULL)
153             ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data);
154         else if ((as = llvm::dyn_cast<AssertStmt>(node)) != NULL)
155             as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data);
156         else if ((dels = llvm::dyn_cast<DeleteStmt>(node)) != NULL)
157             dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data);
158         else if ((ums = llvm::dyn_cast<UnmaskedStmt>(node)) != NULL)
159             ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data);
160         else
161             FATAL("Unhandled statement type in WalkAST()");
162     } else {
163         ///////////////////////////////////////////////////////////////////////////
164         // Handle expressions
165         Assert(llvm::dyn_cast<Expr>(node) != NULL);
166         UnaryExpr *ue;
167         BinaryExpr *be;
168         AssignExpr *ae;
169         SelectExpr *se;
170         ExprList *el;
171         FunctionCallExpr *fce;
172         IndexExpr *ie;
173         MemberExpr *me;
174         TypeCastExpr *tce;
175         ReferenceExpr *re;
176         PtrDerefExpr *ptrderef;
177         RefDerefExpr *refderef;
178         SizeOfExpr *soe;
179         AddressOfExpr *aoe;
180         NewExpr *newe;
181         AllocaExpr *alloce;
182 
183         if ((ue = llvm::dyn_cast<UnaryExpr>(node)) != NULL)
184             ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data);
185         else if ((be = llvm::dyn_cast<BinaryExpr>(node)) != NULL) {
186             be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data);
187             be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data);
188         } else if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) {
189             ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data);
190             ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data);
191         } else if ((se = llvm::dyn_cast<SelectExpr>(node)) != NULL) {
192             se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data);
193             se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data);
194             se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data);
195         } else if ((el = llvm::dyn_cast<ExprList>(node)) != NULL) {
196             for (unsigned int i = 0; i < el->exprs.size(); ++i)
197                 el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, postFunc, data);
198         } else if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) {
199             fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data);
200             fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data);
201             for (int k = 0; k < 3; k++)
202                 fce->launchCountExpr[k] = (Expr *)WalkAST(fce->launchCountExpr[k], preFunc, postFunc, data);
203         } else if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL) {
204             ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data);
205             ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data);
206         } else if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL)
207             me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data);
208         else if ((tce = llvm::dyn_cast<TypeCastExpr>(node)) != NULL)
209             tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data);
210         else if ((re = llvm::dyn_cast<ReferenceExpr>(node)) != NULL)
211             re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data);
212         else if ((ptrderef = llvm::dyn_cast<PtrDerefExpr>(node)) != NULL)
213             ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, data);
214         else if ((refderef = llvm::dyn_cast<RefDerefExpr>(node)) != NULL)
215             refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, data);
216         else if ((soe = llvm::dyn_cast<SizeOfExpr>(node)) != NULL)
217             soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data);
218         else if ((alloce = llvm::dyn_cast<AllocaExpr>(node)) != NULL)
219             alloce->expr = (Expr *)WalkAST(alloce->expr, preFunc, postFunc, data);
220         else if ((aoe = llvm::dyn_cast<AddressOfExpr>(node)) != NULL)
221             aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data);
222         else if ((newe = llvm::dyn_cast<NewExpr>(node)) != NULL) {
223             newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc, postFunc, data);
224             newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc, postFunc, data);
225         } else if (llvm::dyn_cast<SymbolExpr>(node) != NULL || llvm::dyn_cast<ConstExpr>(node) != NULL ||
226                    llvm::dyn_cast<FunctionSymbolExpr>(node) != NULL || llvm::dyn_cast<SyncExpr>(node) != NULL ||
227                    llvm::dyn_cast<NullPointerExpr>(node) != NULL) {
228             // nothing to do
229         } else
230             FATAL("Unhandled expression type in WalkAST().");
231     }
232 
233     // Call the callback function
234     if (postFunc != NULL)
235         return postFunc(node, data);
236     else
237         return node;
238 }
239 
lOptimizeNode(ASTNode * node,void *)240 static ASTNode *lOptimizeNode(ASTNode *node, void *) { return node->Optimize(); }
241 
Optimize(ASTNode * root)242 ASTNode *ispc::Optimize(ASTNode *root) { return WalkAST(root, NULL, lOptimizeNode, NULL); }
243 
Optimize(Expr * expr)244 Expr *ispc::Optimize(Expr *expr) { return (Expr *)Optimize((ASTNode *)expr); }
245 
Optimize(Stmt * stmt)246 Stmt *ispc::Optimize(Stmt *stmt) { return (Stmt *)Optimize((ASTNode *)stmt); }
247 
lTypeCheckNode(ASTNode * node,void *)248 static ASTNode *lTypeCheckNode(ASTNode *node, void *) { return node->TypeCheck(); }
249 
TypeCheck(ASTNode * root)250 ASTNode *ispc::TypeCheck(ASTNode *root) { return WalkAST(root, NULL, lTypeCheckNode, NULL); }
251 
TypeCheck(Expr * expr)252 Expr *ispc::TypeCheck(Expr *expr) { return (Expr *)TypeCheck((ASTNode *)expr); }
253 
TypeCheck(Stmt * stmt)254 Stmt *ispc::TypeCheck(Stmt *stmt) { return (Stmt *)TypeCheck((ASTNode *)stmt); }
255 
/** State shared by the cost-estimation callbacks: a running cost total and
    a counter tracking ForeachStmt nesting.  Both start at zero.
 */
struct CostData {
    CostData() : cost(0), foreachDepth(0) {}

    int cost;         // accumulated EstimateCost() total
    int foreachDepth; // incremented on entering a ForeachStmt, decremented on leaving
};
262 
lCostCallbackPre(ASTNode * node,void * d)263 static bool lCostCallbackPre(ASTNode *node, void *d) {
264     CostData *data = (CostData *)d;
265     if (llvm::dyn_cast<ForeachStmt>(node) != NULL)
266         ++data->foreachDepth;
267     if (data->foreachDepth == 0)
268         data->cost += node->EstimateCost();
269     return true;
270 }
271 
lCostCallbackPost(ASTNode * node,void * d)272 static ASTNode *lCostCallbackPost(ASTNode *node, void *d) {
273     CostData *data = (CostData *)d;
274     if (llvm::dyn_cast<ForeachStmt>(node) != NULL)
275         --data->foreachDepth;
276     return node;
277 }
278 
EstimateCost(ASTNode * root)279 int ispc::EstimateCost(ASTNode *root) {
280     CostData data;
281     WalkAST(root, lCostCallbackPre, lCostCallbackPost, &data);
282     return data.cost;
283 }
284 
285 /** Given an AST node, check to see if it's safe if we happen to run the
286     code for that node with the execution mask all off.
287  */
lCheckAllOffSafety(ASTNode * node,void * data)288 static bool lCheckAllOffSafety(ASTNode *node, void *data) {
289     bool *okPtr = (bool *)data;
290 
291     FunctionCallExpr *fce;
292     if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) {
293         if (fce->func == NULL)
294             return false;
295 
296         const Type *type = fce->func->GetType();
297         const PointerType *pt = CastType<PointerType>(type);
298         if (pt != NULL)
299             type = pt->GetBaseType();
300         const FunctionType *ftype = CastType<FunctionType>(type);
301         Assert(ftype != NULL);
302 
303         if (ftype->isSafe == false) {
304             *okPtr = false;
305             return false;
306         }
307     }
308 
309     if (llvm::dyn_cast<AssertStmt>(node) != NULL) {
310         // While it's fine to run the assert for varying tests, it's not
311         // desirable to check an assert on a uniform variable if all of the
312         // lanes are off.
313         *okPtr = false;
314         return false;
315     }
316 
317     if (llvm::dyn_cast<PrintStmt>(node) != NULL) {
318         *okPtr = false;
319         return false;
320     }
321 
322     if (llvm::dyn_cast<NewExpr>(node) != NULL || llvm::dyn_cast<DeleteStmt>(node) != NULL) {
323         // We definitely don't want to run the uniform variants of these if
324         // the mask is all off.  It's also worth skipping the overhead of
325         // executing the varying versions of them in the all-off mask case.
326         *okPtr = false;
327         return false;
328     }
329 
330     if (llvm::dyn_cast<ForeachStmt>(node) != NULL || llvm::dyn_cast<ForeachActiveStmt>(node) != NULL ||
331         llvm::dyn_cast<ForeachUniqueStmt>(node) != NULL || llvm::dyn_cast<UnmaskedStmt>(node) != NULL) {
332         // The various foreach statements also shouldn't be run with an
333         // all-off mask.  Since they can re-establish an 'all on' mask,
334         // this would be pretty unintuitive.  (More generally, it's
335         // possibly a little strange to allow foreach in the presence of
336         // any non-uniform control flow...)
337         //
338         // Similarly, the implementation of foreach_unique assumes as a
339         // precondition that the mask won't be all off going into it, so
340         // we'll enforce that here...
341         *okPtr = false;
342         return false;
343     }
344 
345     BinaryExpr *binaryExpr;
346     if ((binaryExpr = llvm::dyn_cast<BinaryExpr>(node)) != NULL) {
347         if (binaryExpr->op == BinaryExpr::Mod || binaryExpr->op == BinaryExpr::Div) {
348             *okPtr = false;
349             return false;
350         }
351     }
352     IndexExpr *ie;
353     if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL && ie->baseExpr != NULL) {
354         const Type *type = ie->baseExpr->GetType();
355         if (type == NULL)
356             return true;
357         if (CastType<ReferenceType>(type) != NULL)
358             type = type->GetReferenceTarget();
359 
360         ConstExpr *ce = llvm::dyn_cast<ConstExpr>(ie->index);
361         if (ce == NULL) {
362             // indexing with a variable... -> not safe
363             *okPtr = false;
364             return false;
365         }
366 
367         const PointerType *pointerType = CastType<PointerType>(type);
368         if (pointerType != NULL) {
369             // pointer[index] -> can't be sure -> not safe
370             *okPtr = false;
371             return false;
372         }
373 
374         const SequentialType *seqType = CastType<SequentialType>(type);
375         Assert(seqType != NULL);
376         int nElements = seqType->GetElementCount();
377         if (nElements == 0) {
378             // Unsized array, so we can't be sure -> not safe
379             *okPtr = false;
380             return false;
381         }
382 
383         int32_t indices[ISPC_MAX_NVEC];
384         int count = ce->GetValues(indices);
385         for (int i = 0; i < count; ++i) {
386             if (indices[i] < 0 || indices[i] >= nElements) {
387                 // Index is out of bounds -> not safe
388                 *okPtr = false;
389                 return false;
390             }
391         }
392 
393         // All indices are in-bounds
394         return true;
395     }
396 
397     MemberExpr *me;
398     if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL && me->dereferenceExpr) {
399         *okPtr = false;
400         return false;
401     }
402 
403     if (llvm::dyn_cast<PtrDerefExpr>(node) != NULL) {
404         *okPtr = false;
405         return false;
406     }
407 
408     /*
409       Don't allow turning if/else to straight-line-code if we
410       assign to a uniform.
411     */
412     AssignExpr *ae;
413     if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) {
414         if (ae->GetType()) {
415             if (ae->GetType()->IsUniformType()) {
416                 *okPtr = false;
417                 return false;
418             }
419         }
420     }
421 
422     return true;
423 }
424 
SafeToRunWithMaskAllOff(ASTNode * root)425 bool ispc::SafeToRunWithMaskAllOff(ASTNode *root) {
426     bool safe = true;
427     WalkAST(root, lCheckAllOffSafety, NULL, &safe);
428     return safe;
429 }
430