1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 // Analysis of the AST needed for HLSL generation
8 
9 #include "compiler/translator/ASTMetadataHLSL.h"
10 
11 #include "compiler/translator/CallDAG.h"
12 #include "compiler/translator/IntermTraverse.h"
13 #include "compiler/translator/SymbolTable.h"
14 
15 namespace sh
16 {
17 
18 namespace
19 {
20 
// Class used to traverse the AST of a function definition, checking whether the
// function uses a gradient and recording the set of control-flow statements that
// contain a gradient. It assumes that the analysis has already been done for the
// function's callees.
25 class PullGradient : public TIntermTraverser
26 {
27   public:
PullGradient(MetadataList * metadataList,size_t index,const CallDAG & dag)28     PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
29         : TIntermTraverser(true, false, true),
30           mMetadataList(metadataList),
31           mMetadata(&(*metadataList)[index]),
32           mIndex(index),
33           mDag(dag)
34     {
35         ASSERT(index < metadataList->size());
36 
37         // ESSL 100 builtin gradient functions
38         mGradientBuiltinFunctions.insert("texture2D");
39         mGradientBuiltinFunctions.insert("texture2DProj");
40         mGradientBuiltinFunctions.insert("textureCube");
41 
42         // ESSL 300 builtin gradient functions
43         mGradientBuiltinFunctions.insert("texture");
44         mGradientBuiltinFunctions.insert("textureProj");
45         mGradientBuiltinFunctions.insert("textureOffset");
46         mGradientBuiltinFunctions.insert("textureProjOffset");
47 
48         // ESSL 310 doesn't add builtin gradient functions
49     }
50 
traverse(TIntermFunctionDefinition * node)51     void traverse(TIntermFunctionDefinition *node)
52     {
53         node->traverse(this);
54         ASSERT(mParents.empty());
55     }
56 
57     // Called when a gradient operation or a call to a function using a gradient is found.
onGradient()58     void onGradient()
59     {
60         mMetadata->mUsesGradient = true;
61         // Mark the latest control flow as using a gradient.
62         if (!mParents.empty())
63         {
64             mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
65         }
66     }
67 
visitControlFlow(Visit visit,TIntermNode * node)68     void visitControlFlow(Visit visit, TIntermNode *node)
69     {
70         if (visit == PreVisit)
71         {
72             mParents.push_back(node);
73         }
74         else if (visit == PostVisit)
75         {
76             ASSERT(mParents.back() == node);
77             mParents.pop_back();
78             // A control flow's using a gradient means its parents are too.
79             if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
80             {
81                 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
82             }
83         }
84     }
85 
visitLoop(Visit visit,TIntermLoop * loop)86     bool visitLoop(Visit visit, TIntermLoop *loop) override
87     {
88         visitControlFlow(visit, loop);
89         return true;
90     }
91 
visitIfElse(Visit visit,TIntermIfElse * ifElse)92     bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
93     {
94         visitControlFlow(visit, ifElse);
95         return true;
96     }
97 
visitUnary(Visit visit,TIntermUnary * node)98     bool visitUnary(Visit visit, TIntermUnary *node) override
99     {
100         if (visit == PreVisit)
101         {
102             switch (node->getOp())
103             {
104                 case EOpDFdx:
105                 case EOpDFdy:
106                 case EOpFwidth:
107                     onGradient();
108                 default:
109                     break;
110             }
111         }
112 
113         return true;
114     }
115 
visitAggregate(Visit visit,TIntermAggregate * node)116     bool visitAggregate(Visit visit, TIntermAggregate *node) override
117     {
118         if (visit == PreVisit)
119         {
120             if (node->getOp() == EOpCallFunctionInAST)
121             {
122                 size_t calleeIndex = mDag.findIndex(node->getFunctionSymbolInfo());
123                 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
124 
125                 if ((*mMetadataList)[calleeIndex].mUsesGradient)
126                 {
127                     onGradient();
128                 }
129             }
130             else if (node->getOp() == EOpCallBuiltInFunction)
131             {
132                 if (mGradientBuiltinFunctions.find(node->getFunctionSymbolInfo()->getName()) !=
133                     mGradientBuiltinFunctions.end())
134                 {
135                     onGradient();
136                 }
137             }
138         }
139 
140         return true;
141     }
142 
143   private:
144     MetadataList *mMetadataList;
145     ASTMetadataHLSL *mMetadata;
146     size_t mIndex;
147     const CallDAG &mDag;
148 
149     // Contains a stack of the control flow nodes that are parents of the node being
150     // currently visited. It is used to mark control flows using a gradient.
151     std::vector<TIntermNode *> mParents;
152 
153     // A list of builtin functions that use gradients
154     std::set<TString> mGradientBuiltinFunctions;
155 };
156 
// Traverses the AST of a function definition to compute the discontinuous loops
// and the if statements containing gradient loops. It assumes that the gradient loops
// (loops that contain a gradient) have already been computed and that the current
// function's callees have already been traversed.
161 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
162 {
163   public:
PullComputeDiscontinuousAndGradientLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)164     PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
165                                              size_t index,
166                                              const CallDAG &dag)
167         : TIntermTraverser(true, false, true),
168           mMetadataList(metadataList),
169           mMetadata(&(*metadataList)[index]),
170           mIndex(index),
171           mDag(dag)
172     {
173     }
174 
traverse(TIntermFunctionDefinition * node)175     void traverse(TIntermFunctionDefinition *node)
176     {
177         node->traverse(this);
178         ASSERT(mLoopsAndSwitches.empty());
179         ASSERT(mIfs.empty());
180     }
181 
182     // Called when traversing a gradient loop or a call to a function with a
183     // gradient loop in its call graph.
onGradientLoop()184     void onGradientLoop()
185     {
186         mMetadata->mHasGradientLoopInCallGraph = true;
187         // Mark the latest if as using a discontinuous loop.
188         if (!mIfs.empty())
189         {
190             mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
191         }
192     }
193 
visitLoop(Visit visit,TIntermLoop * loop)194     bool visitLoop(Visit visit, TIntermLoop *loop) override
195     {
196         if (visit == PreVisit)
197         {
198             mLoopsAndSwitches.push_back(loop);
199 
200             if (mMetadata->hasGradientInCallGraph(loop))
201             {
202                 onGradientLoop();
203             }
204         }
205         else if (visit == PostVisit)
206         {
207             ASSERT(mLoopsAndSwitches.back() == loop);
208             mLoopsAndSwitches.pop_back();
209         }
210 
211         return true;
212     }
213 
visitIfElse(Visit visit,TIntermIfElse * node)214     bool visitIfElse(Visit visit, TIntermIfElse *node) override
215     {
216         if (visit == PreVisit)
217         {
218             mIfs.push_back(node);
219         }
220         else if (visit == PostVisit)
221         {
222             ASSERT(mIfs.back() == node);
223             mIfs.pop_back();
224             // An if using a discontinuous loop means its parents ifs are also discontinuous.
225             if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
226             {
227                 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
228             }
229         }
230 
231         return true;
232     }
233 
visitBranch(Visit visit,TIntermBranch * node)234     bool visitBranch(Visit visit, TIntermBranch *node) override
235     {
236         if (visit == PreVisit)
237         {
238             switch (node->getFlowOp())
239             {
240                 case EOpBreak:
241                 {
242                     ASSERT(!mLoopsAndSwitches.empty());
243                     TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
244                     if (loop != nullptr)
245                     {
246                         mMetadata->mDiscontinuousLoops.insert(loop);
247                     }
248                 }
249                 break;
250                 case EOpContinue:
251                 {
252                     ASSERT(!mLoopsAndSwitches.empty());
253                     TIntermLoop *loop = nullptr;
254                     size_t i          = mLoopsAndSwitches.size();
255                     while (loop == nullptr && i > 0)
256                     {
257                         --i;
258                         loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
259                     }
260                     ASSERT(loop != nullptr);
261                     mMetadata->mDiscontinuousLoops.insert(loop);
262                 }
263                 break;
264                 case EOpKill:
265                 case EOpReturn:
266                     // A return or discard jumps out of all the enclosing loops
267                     if (!mLoopsAndSwitches.empty())
268                     {
269                         for (TIntermNode *intermNode : mLoopsAndSwitches)
270                         {
271                             TIntermLoop *loop = intermNode->getAsLoopNode();
272                             if (loop)
273                             {
274                                 mMetadata->mDiscontinuousLoops.insert(loop);
275                             }
276                         }
277                     }
278                     break;
279                 default:
280                     UNREACHABLE();
281             }
282         }
283 
284         return true;
285     }
286 
visitAggregate(Visit visit,TIntermAggregate * node)287     bool visitAggregate(Visit visit, TIntermAggregate *node) override
288     {
289         if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
290         {
291             size_t calleeIndex = mDag.findIndex(node->getFunctionSymbolInfo());
292             ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
293 
294             if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
295             {
296                 onGradientLoop();
297             }
298         }
299 
300         return true;
301     }
302 
visitSwitch(Visit visit,TIntermSwitch * node)303     bool visitSwitch(Visit visit, TIntermSwitch *node) override
304     {
305         if (visit == PreVisit)
306         {
307             mLoopsAndSwitches.push_back(node);
308         }
309         else if (visit == PostVisit)
310         {
311             ASSERT(mLoopsAndSwitches.back() == node);
312             mLoopsAndSwitches.pop_back();
313         }
314         return true;
315     }
316 
317   private:
318     MetadataList *mMetadataList;
319     ASTMetadataHLSL *mMetadata;
320     size_t mIndex;
321     const CallDAG &mDag;
322 
323     std::vector<TIntermNode *> mLoopsAndSwitches;
324     std::vector<TIntermIfElse *> mIfs;
325 };
326 
// Tags all the functions called in a discontinuous loop, or called from a function
// that is itself called in a discontinuous loop.
328 class PushDiscontinuousLoops : public TIntermTraverser
329 {
330   public:
PushDiscontinuousLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)331     PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
332         : TIntermTraverser(true, true, true),
333           mMetadataList(metadataList),
334           mMetadata(&(*metadataList)[index]),
335           mIndex(index),
336           mDag(dag),
337           mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
338     {
339     }
340 
traverse(TIntermFunctionDefinition * node)341     void traverse(TIntermFunctionDefinition *node)
342     {
343         node->traverse(this);
344         ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
345     }
346 
visitLoop(Visit visit,TIntermLoop * loop)347     bool visitLoop(Visit visit, TIntermLoop *loop) override
348     {
349         bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
350 
351         if (visit == PreVisit && isDiscontinuous)
352         {
353             mNestedDiscont++;
354         }
355         else if (visit == PostVisit && isDiscontinuous)
356         {
357             mNestedDiscont--;
358         }
359 
360         return true;
361     }
362 
visitAggregate(Visit visit,TIntermAggregate * node)363     bool visitAggregate(Visit visit, TIntermAggregate *node) override
364     {
365         switch (node->getOp())
366         {
367             case EOpCallFunctionInAST:
368                 if (visit == PreVisit && mNestedDiscont > 0)
369                 {
370                     size_t calleeIndex = mDag.findIndex(node->getFunctionSymbolInfo());
371                     ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
372 
373                     (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
374                 }
375                 break;
376             default:
377                 break;
378         }
379         return true;
380     }
381 
382   private:
383     MetadataList *mMetadataList;
384     ASTMetadataHLSL *mMetadata;
385     size_t mIndex;
386     const CallDAG &mDag;
387 
388     int mNestedDiscont;
389 };
390 }
391 
hasGradientInCallGraph(TIntermLoop * node)392 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
393 {
394     return mControlFlowsContainingGradient.count(node) > 0;
395 }
396 
hasGradientLoop(TIntermIfElse * node)397 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
398 {
399     return mIfsContainingGradientLoop.count(node) > 0;
400 }
401 
CreateASTMetadataHLSL(TIntermNode * root,const CallDAG & callDag)402 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
403 {
404     MetadataList metadataList(callDag.size());
405 
406     // Compute all the information related to when gradient operations are used.
407     // We want to know for each function and control flow operation if they have
408     // a gradient operation in their call graph (shortened to "using a gradient"
409     // in the rest of the file).
410     //
411     // This computation is logically split in three steps:
412     //  1 - For each function compute if it uses a gradient in its body, ignoring
413     // calls to other user-defined functions.
414     //  2 - For each function determine if it uses a gradient in its call graph,
415     // using the result of step 1 and the CallDAG to know its callees.
416     //  3 - For each control flow statement of each function, check if it uses a
417     // gradient in the function's body, or if it calls a user-defined function that
418     // uses a gradient.
419     //
420     // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
421     // for leaves first, then going down the tree. This is correct because 1 doesn't
422     // depend on other functions, and 2 and 3 depend only on callees.
423     for (size_t i = 0; i < callDag.size(); i++)
424     {
425         PullGradient pull(&metadataList, i, callDag);
426         pull.traverse(callDag.getRecordFromIndex(i).node);
427     }
428 
429     // Compute which loops are discontinuous and which function are called in
430     // these loops. The same way computing gradient usage is a "pull" process,
431     // computing "bing used in a discont. loop" is a push process. However we also
432     // need to know what ifs have a discontinuous loop inside so we do the same type
433     // of callgraph analysis as for the gradient.
434 
435     // First compute which loops are discontinuous (no specific order) and pull
436     // the ifs and functions using a gradient loop.
437     for (size_t i = 0; i < callDag.size(); i++)
438     {
439         PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
440         pull.traverse(callDag.getRecordFromIndex(i).node);
441     }
442 
443     // Then push the information to callees, either from the a local discontinuous
444     // loop or from the caller being called in a discontinuous loop already
445     for (size_t i = callDag.size(); i-- > 0;)
446     {
447         PushDiscontinuousLoops push(&metadataList, i, callDag);
448         push.traverse(callDag.getRecordFromIndex(i).node);
449     }
450 
451     // We create "Lod0" version of functions with the gradient operations replaced
452     // by non-gradient operations so that the D3D compiler is happier with discont
453     // loops.
454     for (auto &metadata : metadataList)
455     {
456         metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
457     }
458 
459     return metadataList;
460 }
461 
462 }  // namespace sh
463