1 /*
2 Copyright (c) 2011-2021, Intel Corporation
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are
7 met:
8
9 * Redistributions of source code must retain the above copyright
10 notice, this list of conditions and the following disclaimer.
11
12 * Redistributions in binary form must reproduce the above copyright
13 notice, this list of conditions and the following disclaimer in the
14 documentation and/or other materials provided with the distribution.
15
16 * Neither the name of Intel Corporation nor the names of its
17 contributors may be used to endorse or promote products derived from
18 this software without specific prior written permission.
19
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22 IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23 TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24 PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 /** @file ast.cpp
35
36 @brief General functionality related to abstract syntax trees and
37 traversal of them.
38 */
39
40 #include "ast.h"
41 #include "expr.h"
42 #include "func.h"
43 #include "stmt.h"
44 #include "sym.h"
45 #include "util.h"
46
47 #include <llvm/Support/TimeProfiler.h>
48
49 using namespace ispc;
50
51 ///////////////////////////////////////////////////////////////////////////
52 // ASTNode
53
~ASTNode()54 ASTNode::~ASTNode() {}
55
56 ///////////////////////////////////////////////////////////////////////////
57 // AST
58
AddFunction(Symbol * sym,Stmt * code)59 void AST::AddFunction(Symbol *sym, Stmt *code) {
60 if (sym == NULL)
61 return;
62 functions.push_back(new Function(sym, code));
63 }
64
GenerateIR()65 void AST::GenerateIR() {
66 llvm::TimeTraceScope TimeScope("GenerateIR");
67 for (unsigned int i = 0; i < functions.size(); ++i)
68 functions[i]->GenerateIR();
69 }
70
71 ///////////////////////////////////////////////////////////////////////////
72
WalkAST(ASTNode * node,ASTPreCallBackFunc preFunc,ASTPostCallBackFunc postFunc,void * data)73 ASTNode *ispc::WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, void *data) {
74 if (node == NULL)
75 return node;
76
77 // Call the callback function
78 if (preFunc != NULL) {
79 if (preFunc(node, data) == false)
80 // The function asked us to not continue recursively, so stop.
81 return node;
82 }
83
84 ////////////////////////////////////////////////////////////////////////////
85 // Handle Statements
86 if (llvm::dyn_cast<Stmt>(node) != NULL) {
87 ExprStmt *es;
88 DeclStmt *ds;
89 IfStmt *is;
90 DoStmt *dos;
91 ForStmt *fs;
92 ForeachStmt *fes;
93 ForeachActiveStmt *fas;
94 ForeachUniqueStmt *fus;
95 CaseStmt *cs;
96 DefaultStmt *defs;
97 SwitchStmt *ss;
98 ReturnStmt *rs;
99 LabeledStmt *ls;
100 StmtList *sl;
101 PrintStmt *ps;
102 AssertStmt *as;
103 DeleteStmt *dels;
104 UnmaskedStmt *ums;
105
106 if ((es = llvm::dyn_cast<ExprStmt>(node)) != NULL)
107 es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data);
108 else if ((ds = llvm::dyn_cast<DeclStmt>(node)) != NULL) {
109 for (unsigned int i = 0; i < ds->vars.size(); ++i)
110 ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, postFunc, data);
111 } else if ((is = llvm::dyn_cast<IfStmt>(node)) != NULL) {
112 is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data);
113 is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, postFunc, data);
114 is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, postFunc, data);
115 } else if ((dos = llvm::dyn_cast<DoStmt>(node)) != NULL) {
116 dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, postFunc, data);
117 dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, postFunc, data);
118 } else if ((fs = llvm::dyn_cast<ForStmt>(node)) != NULL) {
119 fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data);
120 fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data);
121 fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data);
122 fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data);
123 } else if ((fes = llvm::dyn_cast<ForeachStmt>(node)) != NULL) {
124 for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
125 fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, postFunc, data);
126 for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
127 fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, postFunc, data);
128 fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data);
129 } else if ((fas = llvm::dyn_cast<ForeachActiveStmt>(node)) != NULL) {
130 fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data);
131 } else if ((fus = llvm::dyn_cast<ForeachUniqueStmt>(node)) != NULL) {
132 fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data);
133 fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data);
134 } else if ((cs = llvm::dyn_cast<CaseStmt>(node)) != NULL)
135 cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data);
136 else if ((defs = llvm::dyn_cast<DefaultStmt>(node)) != NULL)
137 defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data);
138 else if ((ss = llvm::dyn_cast<SwitchStmt>(node)) != NULL) {
139 ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data);
140 ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data);
141 } else if (llvm::dyn_cast<BreakStmt>(node) != NULL || llvm::dyn_cast<ContinueStmt>(node) != NULL ||
142 llvm::dyn_cast<GotoStmt>(node) != NULL) {
143 // nothing
144 } else if ((ls = llvm::dyn_cast<LabeledStmt>(node)) != NULL)
145 ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data);
146 else if ((rs = llvm::dyn_cast<ReturnStmt>(node)) != NULL)
147 rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data);
148 else if ((sl = llvm::dyn_cast<StmtList>(node)) != NULL) {
149 std::vector<Stmt *> &sls = sl->stmts;
150 for (unsigned int i = 0; i < sls.size(); ++i)
151 sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data);
152 } else if ((ps = llvm::dyn_cast<PrintStmt>(node)) != NULL)
153 ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data);
154 else if ((as = llvm::dyn_cast<AssertStmt>(node)) != NULL)
155 as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data);
156 else if ((dels = llvm::dyn_cast<DeleteStmt>(node)) != NULL)
157 dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data);
158 else if ((ums = llvm::dyn_cast<UnmaskedStmt>(node)) != NULL)
159 ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data);
160 else
161 FATAL("Unhandled statement type in WalkAST()");
162 } else {
163 ///////////////////////////////////////////////////////////////////////////
164 // Handle expressions
165 Assert(llvm::dyn_cast<Expr>(node) != NULL);
166 UnaryExpr *ue;
167 BinaryExpr *be;
168 AssignExpr *ae;
169 SelectExpr *se;
170 ExprList *el;
171 FunctionCallExpr *fce;
172 IndexExpr *ie;
173 MemberExpr *me;
174 TypeCastExpr *tce;
175 ReferenceExpr *re;
176 PtrDerefExpr *ptrderef;
177 RefDerefExpr *refderef;
178 SizeOfExpr *soe;
179 AddressOfExpr *aoe;
180 NewExpr *newe;
181 AllocaExpr *alloce;
182
183 if ((ue = llvm::dyn_cast<UnaryExpr>(node)) != NULL)
184 ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data);
185 else if ((be = llvm::dyn_cast<BinaryExpr>(node)) != NULL) {
186 be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data);
187 be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data);
188 } else if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) {
189 ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data);
190 ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data);
191 } else if ((se = llvm::dyn_cast<SelectExpr>(node)) != NULL) {
192 se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data);
193 se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data);
194 se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data);
195 } else if ((el = llvm::dyn_cast<ExprList>(node)) != NULL) {
196 for (unsigned int i = 0; i < el->exprs.size(); ++i)
197 el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, postFunc, data);
198 } else if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) {
199 fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data);
200 fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data);
201 for (int k = 0; k < 3; k++)
202 fce->launchCountExpr[k] = (Expr *)WalkAST(fce->launchCountExpr[k], preFunc, postFunc, data);
203 } else if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL) {
204 ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data);
205 ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data);
206 } else if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL)
207 me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data);
208 else if ((tce = llvm::dyn_cast<TypeCastExpr>(node)) != NULL)
209 tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data);
210 else if ((re = llvm::dyn_cast<ReferenceExpr>(node)) != NULL)
211 re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data);
212 else if ((ptrderef = llvm::dyn_cast<PtrDerefExpr>(node)) != NULL)
213 ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, data);
214 else if ((refderef = llvm::dyn_cast<RefDerefExpr>(node)) != NULL)
215 refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, data);
216 else if ((soe = llvm::dyn_cast<SizeOfExpr>(node)) != NULL)
217 soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data);
218 else if ((alloce = llvm::dyn_cast<AllocaExpr>(node)) != NULL)
219 alloce->expr = (Expr *)WalkAST(alloce->expr, preFunc, postFunc, data);
220 else if ((aoe = llvm::dyn_cast<AddressOfExpr>(node)) != NULL)
221 aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data);
222 else if ((newe = llvm::dyn_cast<NewExpr>(node)) != NULL) {
223 newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc, postFunc, data);
224 newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc, postFunc, data);
225 } else if (llvm::dyn_cast<SymbolExpr>(node) != NULL || llvm::dyn_cast<ConstExpr>(node) != NULL ||
226 llvm::dyn_cast<FunctionSymbolExpr>(node) != NULL || llvm::dyn_cast<SyncExpr>(node) != NULL ||
227 llvm::dyn_cast<NullPointerExpr>(node) != NULL) {
228 // nothing to do
229 } else
230 FATAL("Unhandled expression type in WalkAST().");
231 }
232
233 // Call the callback function
234 if (postFunc != NULL)
235 return postFunc(node, data);
236 else
237 return node;
238 }
239
lOptimizeNode(ASTNode * node,void *)240 static ASTNode *lOptimizeNode(ASTNode *node, void *) { return node->Optimize(); }
241
Optimize(ASTNode * root)242 ASTNode *ispc::Optimize(ASTNode *root) { return WalkAST(root, NULL, lOptimizeNode, NULL); }
243
Optimize(Expr * expr)244 Expr *ispc::Optimize(Expr *expr) { return (Expr *)Optimize((ASTNode *)expr); }
245
Optimize(Stmt * stmt)246 Stmt *ispc::Optimize(Stmt *stmt) { return (Stmt *)Optimize((ASTNode *)stmt); }
247
lTypeCheckNode(ASTNode * node,void *)248 static ASTNode *lTypeCheckNode(ASTNode *node, void *) { return node->TypeCheck(); }
249
TypeCheck(ASTNode * root)250 ASTNode *ispc::TypeCheck(ASTNode *root) { return WalkAST(root, NULL, lTypeCheckNode, NULL); }
251
TypeCheck(Expr * expr)252 Expr *ispc::TypeCheck(Expr *expr) { return (Expr *)TypeCheck((ASTNode *)expr); }
253
TypeCheck(Stmt * stmt)254 Stmt *ispc::TypeCheck(Stmt *stmt) { return (Stmt *)TypeCheck((ASTNode *)stmt); }
255
/** Traversal state threaded through the EstimateCost() WalkAST() callbacks. */
struct CostData {
    CostData() : cost(0), foreachDepth(0) {}

    /// Running total of the per-node cost estimates counted so far.
    int cost;
    /// Number of ForeachStmt nodes enclosing the node currently being visited.
    int foreachDepth;
};
262
lCostCallbackPre(ASTNode * node,void * d)263 static bool lCostCallbackPre(ASTNode *node, void *d) {
264 CostData *data = (CostData *)d;
265 if (llvm::dyn_cast<ForeachStmt>(node) != NULL)
266 ++data->foreachDepth;
267 if (data->foreachDepth == 0)
268 data->cost += node->EstimateCost();
269 return true;
270 }
271
lCostCallbackPost(ASTNode * node,void * d)272 static ASTNode *lCostCallbackPost(ASTNode *node, void *d) {
273 CostData *data = (CostData *)d;
274 if (llvm::dyn_cast<ForeachStmt>(node) != NULL)
275 --data->foreachDepth;
276 return node;
277 }
278
EstimateCost(ASTNode * root)279 int ispc::EstimateCost(ASTNode *root) {
280 CostData data;
281 WalkAST(root, lCostCallbackPre, lCostCallbackPost, &data);
282 return data.cost;
283 }
284
285 /** Given an AST node, check to see if it's safe if we happen to run the
286 code for that node with the execution mask all off.
287 */
lCheckAllOffSafety(ASTNode * node,void * data)288 static bool lCheckAllOffSafety(ASTNode *node, void *data) {
289 bool *okPtr = (bool *)data;
290
291 FunctionCallExpr *fce;
292 if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) {
293 if (fce->func == NULL)
294 return false;
295
296 const Type *type = fce->func->GetType();
297 const PointerType *pt = CastType<PointerType>(type);
298 if (pt != NULL)
299 type = pt->GetBaseType();
300 const FunctionType *ftype = CastType<FunctionType>(type);
301 Assert(ftype != NULL);
302
303 if (ftype->isSafe == false) {
304 *okPtr = false;
305 return false;
306 }
307 }
308
309 if (llvm::dyn_cast<AssertStmt>(node) != NULL) {
310 // While it's fine to run the assert for varying tests, it's not
311 // desirable to check an assert on a uniform variable if all of the
312 // lanes are off.
313 *okPtr = false;
314 return false;
315 }
316
317 if (llvm::dyn_cast<PrintStmt>(node) != NULL) {
318 *okPtr = false;
319 return false;
320 }
321
322 if (llvm::dyn_cast<NewExpr>(node) != NULL || llvm::dyn_cast<DeleteStmt>(node) != NULL) {
323 // We definitely don't want to run the uniform variants of these if
324 // the mask is all off. It's also worth skipping the overhead of
325 // executing the varying versions of them in the all-off mask case.
326 *okPtr = false;
327 return false;
328 }
329
330 if (llvm::dyn_cast<ForeachStmt>(node) != NULL || llvm::dyn_cast<ForeachActiveStmt>(node) != NULL ||
331 llvm::dyn_cast<ForeachUniqueStmt>(node) != NULL || llvm::dyn_cast<UnmaskedStmt>(node) != NULL) {
332 // The various foreach statements also shouldn't be run with an
333 // all-off mask. Since they can re-establish an 'all on' mask,
334 // this would be pretty unintuitive. (More generally, it's
335 // possibly a little strange to allow foreach in the presence of
336 // any non-uniform control flow...)
337 //
338 // Similarly, the implementation of foreach_unique assumes as a
339 // precondition that the mask won't be all off going into it, so
340 // we'll enforce that here...
341 *okPtr = false;
342 return false;
343 }
344
345 BinaryExpr *binaryExpr;
346 if ((binaryExpr = llvm::dyn_cast<BinaryExpr>(node)) != NULL) {
347 if (binaryExpr->op == BinaryExpr::Mod || binaryExpr->op == BinaryExpr::Div) {
348 *okPtr = false;
349 return false;
350 }
351 }
352 IndexExpr *ie;
353 if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL && ie->baseExpr != NULL) {
354 const Type *type = ie->baseExpr->GetType();
355 if (type == NULL)
356 return true;
357 if (CastType<ReferenceType>(type) != NULL)
358 type = type->GetReferenceTarget();
359
360 ConstExpr *ce = llvm::dyn_cast<ConstExpr>(ie->index);
361 if (ce == NULL) {
362 // indexing with a variable... -> not safe
363 *okPtr = false;
364 return false;
365 }
366
367 const PointerType *pointerType = CastType<PointerType>(type);
368 if (pointerType != NULL) {
369 // pointer[index] -> can't be sure -> not safe
370 *okPtr = false;
371 return false;
372 }
373
374 const SequentialType *seqType = CastType<SequentialType>(type);
375 Assert(seqType != NULL);
376 int nElements = seqType->GetElementCount();
377 if (nElements == 0) {
378 // Unsized array, so we can't be sure -> not safe
379 *okPtr = false;
380 return false;
381 }
382
383 int32_t indices[ISPC_MAX_NVEC];
384 int count = ce->GetValues(indices);
385 for (int i = 0; i < count; ++i) {
386 if (indices[i] < 0 || indices[i] >= nElements) {
387 // Index is out of bounds -> not safe
388 *okPtr = false;
389 return false;
390 }
391 }
392
393 // All indices are in-bounds
394 return true;
395 }
396
397 MemberExpr *me;
398 if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL && me->dereferenceExpr) {
399 *okPtr = false;
400 return false;
401 }
402
403 if (llvm::dyn_cast<PtrDerefExpr>(node) != NULL) {
404 *okPtr = false;
405 return false;
406 }
407
408 /*
409 Don't allow turning if/else to straight-line-code if we
410 assign to a uniform.
411 */
412 AssignExpr *ae;
413 if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) {
414 if (ae->GetType()) {
415 if (ae->GetType()->IsUniformType()) {
416 *okPtr = false;
417 return false;
418 }
419 }
420 }
421
422 return true;
423 }
424
SafeToRunWithMaskAllOff(ASTNode * root)425 bool ispc::SafeToRunWithMaskAllOff(ASTNode *root) {
426 bool safe = true;
427 WalkAST(root, lCheckAllOffSafety, NULL, &safe);
428 return safe;
429 }
430