1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 // Analysis of the AST needed for HLSL generation
8 
9 #include "compiler/translator/ASTMetadataHLSL.h"
10 
11 #include "compiler/translator/CallDAG.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14 
15 namespace sh
16 {
17 
18 namespace
19 {
20 
21 // Class used to traverse the AST of a function definition, checking if the
22 // function uses a gradient, and writing the set of control flow using gradients.
23 // It assumes that the analysis has already been made for the function's
24 // callees.
25 class PullGradient : public TIntermTraverser
26 {
27   public:
PullGradient(MetadataList * metadataList,size_t index,const CallDAG & dag)28     PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
29         : TIntermTraverser(true, false, true),
30           mMetadataList(metadataList),
31           mMetadata(&(*metadataList)[index]),
32           mIndex(index),
33           mDag(dag)
34     {
35         ASSERT(index < metadataList->size());
36 
37         // ESSL 100 builtin gradient functions
38         mGradientBuiltinFunctions.insert(ImmutableString("texture2D"));
39         mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj"));
40         mGradientBuiltinFunctions.insert(ImmutableString("textureCube"));
41 
42         // ESSL 300 builtin gradient functions
43         mGradientBuiltinFunctions.insert(ImmutableString("texture"));
44         mGradientBuiltinFunctions.insert(ImmutableString("textureProj"));
45         mGradientBuiltinFunctions.insert(ImmutableString("textureOffset"));
46         mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset"));
47 
48         // ESSL 310 doesn't add builtin gradient functions
49     }
50 
traverse(TIntermFunctionDefinition * node)51     void traverse(TIntermFunctionDefinition *node)
52     {
53         node->traverse(this);
54         ASSERT(mParents.empty());
55     }
56 
57     // Called when a gradient operation or a call to a function using a gradient is found.
onGradient()58     void onGradient()
59     {
60         mMetadata->mUsesGradient = true;
61         // Mark the latest control flow as using a gradient.
62         if (!mParents.empty())
63         {
64             mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
65         }
66     }
67 
visitControlFlow(Visit visit,TIntermNode * node)68     void visitControlFlow(Visit visit, TIntermNode *node)
69     {
70         if (visit == PreVisit)
71         {
72             mParents.push_back(node);
73         }
74         else if (visit == PostVisit)
75         {
76             ASSERT(mParents.back() == node);
77             mParents.pop_back();
78             // A control flow's using a gradient means its parents are too.
79             if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
80             {
81                 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
82             }
83         }
84     }
85 
visitLoop(Visit visit,TIntermLoop * loop)86     bool visitLoop(Visit visit, TIntermLoop *loop) override
87     {
88         visitControlFlow(visit, loop);
89         return true;
90     }
91 
visitIfElse(Visit visit,TIntermIfElse * ifElse)92     bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
93     {
94         visitControlFlow(visit, ifElse);
95         return true;
96     }
97 
visitUnary(Visit visit,TIntermUnary * node)98     bool visitUnary(Visit visit, TIntermUnary *node) override
99     {
100         if (visit == PreVisit)
101         {
102             switch (node->getOp())
103             {
104                 case EOpDFdx:
105                 case EOpDFdy:
106                 case EOpFwidth:
107                     onGradient();
108                     break;
109                 default:
110                     break;
111             }
112         }
113 
114         return true;
115     }
116 
visitAggregate(Visit visit,TIntermAggregate * node)117     bool visitAggregate(Visit visit, TIntermAggregate *node) override
118     {
119         if (visit == PreVisit)
120         {
121             if (node->getOp() == EOpCallFunctionInAST)
122             {
123                 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
124                 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
125 
126                 if ((*mMetadataList)[calleeIndex].mUsesGradient)
127                 {
128                     onGradient();
129                 }
130             }
131             else if (node->getOp() == EOpCallBuiltInFunction)
132             {
133                 if (mGradientBuiltinFunctions.find(node->getFunction()->name()) !=
134                     mGradientBuiltinFunctions.end())
135                 {
136                     onGradient();
137                 }
138             }
139         }
140 
141         return true;
142     }
143 
144   private:
145     MetadataList *mMetadataList;
146     ASTMetadataHLSL *mMetadata;
147     size_t mIndex;
148     const CallDAG &mDag;
149 
150     // Contains a stack of the control flow nodes that are parents of the node being
151     // currently visited. It is used to mark control flows using a gradient.
152     std::vector<TIntermNode *> mParents;
153 
154     // A list of builtin functions that use gradients
155     std::set<ImmutableString> mGradientBuiltinFunctions;
156 };
157 
158 // Traverses the AST of a function definition to compute the the discontinuous loops
159 // and the if statements containing gradient loops. It assumes that the gradient loops
160 // (loops that contain a gradient) have already been computed and that it has already
161 // traversed the current function's callees.
162 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
163 {
164   public:
PullComputeDiscontinuousAndGradientLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)165     PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
166                                              size_t index,
167                                              const CallDAG &dag)
168         : TIntermTraverser(true, false, true),
169           mMetadataList(metadataList),
170           mMetadata(&(*metadataList)[index]),
171           mIndex(index),
172           mDag(dag)
173     {}
174 
traverse(TIntermFunctionDefinition * node)175     void traverse(TIntermFunctionDefinition *node)
176     {
177         node->traverse(this);
178         ASSERT(mLoopsAndSwitches.empty());
179         ASSERT(mIfs.empty());
180     }
181 
182     // Called when traversing a gradient loop or a call to a function with a
183     // gradient loop in its call graph.
onGradientLoop()184     void onGradientLoop()
185     {
186         mMetadata->mHasGradientLoopInCallGraph = true;
187         // Mark the latest if as using a discontinuous loop.
188         if (!mIfs.empty())
189         {
190             mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
191         }
192     }
193 
visitLoop(Visit visit,TIntermLoop * loop)194     bool visitLoop(Visit visit, TIntermLoop *loop) override
195     {
196         if (visit == PreVisit)
197         {
198             mLoopsAndSwitches.push_back(loop);
199 
200             if (mMetadata->hasGradientInCallGraph(loop))
201             {
202                 onGradientLoop();
203             }
204         }
205         else if (visit == PostVisit)
206         {
207             ASSERT(mLoopsAndSwitches.back() == loop);
208             mLoopsAndSwitches.pop_back();
209         }
210 
211         return true;
212     }
213 
visitIfElse(Visit visit,TIntermIfElse * node)214     bool visitIfElse(Visit visit, TIntermIfElse *node) override
215     {
216         if (visit == PreVisit)
217         {
218             mIfs.push_back(node);
219         }
220         else if (visit == PostVisit)
221         {
222             ASSERT(mIfs.back() == node);
223             mIfs.pop_back();
224             // An if using a discontinuous loop means its parents ifs are also discontinuous.
225             if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
226             {
227                 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
228             }
229         }
230 
231         return true;
232     }
233 
visitBranch(Visit visit,TIntermBranch * node)234     bool visitBranch(Visit visit, TIntermBranch *node) override
235     {
236         if (visit == PreVisit)
237         {
238             switch (node->getFlowOp())
239             {
240                 case EOpBreak:
241                 {
242                     ASSERT(!mLoopsAndSwitches.empty());
243                     TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
244                     if (loop != nullptr)
245                     {
246                         mMetadata->mDiscontinuousLoops.insert(loop);
247                     }
248                 }
249                 break;
250                 case EOpContinue:
251                 {
252                     ASSERT(!mLoopsAndSwitches.empty());
253                     TIntermLoop *loop = nullptr;
254                     size_t i          = mLoopsAndSwitches.size();
255                     while (loop == nullptr && i > 0)
256                     {
257                         --i;
258                         loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
259                     }
260                     ASSERT(loop != nullptr);
261                     mMetadata->mDiscontinuousLoops.insert(loop);
262                 }
263                 break;
264                 case EOpKill:
265                 case EOpReturn:
266                     // A return or discard jumps out of all the enclosing loops
267                     if (!mLoopsAndSwitches.empty())
268                     {
269                         for (TIntermNode *intermNode : mLoopsAndSwitches)
270                         {
271                             TIntermLoop *loop = intermNode->getAsLoopNode();
272                             if (loop)
273                             {
274                                 mMetadata->mDiscontinuousLoops.insert(loop);
275                             }
276                         }
277                     }
278                     break;
279                 default:
280                     UNREACHABLE();
281             }
282         }
283 
284         return true;
285     }
286 
visitAggregate(Visit visit,TIntermAggregate * node)287     bool visitAggregate(Visit visit, TIntermAggregate *node) override
288     {
289         if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
290         {
291             size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
292             ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
293 
294             if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
295             {
296                 onGradientLoop();
297             }
298         }
299 
300         return true;
301     }
302 
visitSwitch(Visit visit,TIntermSwitch * node)303     bool visitSwitch(Visit visit, TIntermSwitch *node) override
304     {
305         if (visit == PreVisit)
306         {
307             mLoopsAndSwitches.push_back(node);
308         }
309         else if (visit == PostVisit)
310         {
311             ASSERT(mLoopsAndSwitches.back() == node);
312             mLoopsAndSwitches.pop_back();
313         }
314         return true;
315     }
316 
317   private:
318     MetadataList *mMetadataList;
319     ASTMetadataHLSL *mMetadata;
320     size_t mIndex;
321     const CallDAG &mDag;
322 
323     std::vector<TIntermNode *> mLoopsAndSwitches;
324     std::vector<TIntermIfElse *> mIfs;
325 };
326 
327 // Tags all the functions called in a discontinuous loop
328 class PushDiscontinuousLoops : public TIntermTraverser
329 {
330   public:
PushDiscontinuousLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)331     PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
332         : TIntermTraverser(true, true, true),
333           mMetadataList(metadataList),
334           mMetadata(&(*metadataList)[index]),
335           mIndex(index),
336           mDag(dag),
337           mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
338     {}
339 
traverse(TIntermFunctionDefinition * node)340     void traverse(TIntermFunctionDefinition *node)
341     {
342         node->traverse(this);
343         ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
344     }
345 
visitLoop(Visit visit,TIntermLoop * loop)346     bool visitLoop(Visit visit, TIntermLoop *loop) override
347     {
348         bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
349 
350         if (visit == PreVisit && isDiscontinuous)
351         {
352             mNestedDiscont++;
353         }
354         else if (visit == PostVisit && isDiscontinuous)
355         {
356             mNestedDiscont--;
357         }
358 
359         return true;
360     }
361 
visitAggregate(Visit visit,TIntermAggregate * node)362     bool visitAggregate(Visit visit, TIntermAggregate *node) override
363     {
364         switch (node->getOp())
365         {
366             case EOpCallFunctionInAST:
367                 if (visit == PreVisit && mNestedDiscont > 0)
368                 {
369                     size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
370                     ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
371 
372                     (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
373                 }
374                 break;
375             default:
376                 break;
377         }
378         return true;
379     }
380 
381   private:
382     MetadataList *mMetadataList;
383     ASTMetadataHLSL *mMetadata;
384     size_t mIndex;
385     const CallDAG &mDag;
386 
387     int mNestedDiscont;
388 };
389 }  // namespace
390 
hasGradientInCallGraph(TIntermLoop * node)391 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
392 {
393     return mControlFlowsContainingGradient.count(node) > 0;
394 }
395 
hasGradientLoop(TIntermIfElse * node)396 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
397 {
398     return mIfsContainingGradientLoop.count(node) > 0;
399 }
400 
CreateASTMetadataHLSL(TIntermNode * root,const CallDAG & callDag)401 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
402 {
403     MetadataList metadataList(callDag.size());
404 
405     // Compute all the information related to when gradient operations are used.
406     // We want to know for each function and control flow operation if they have
407     // a gradient operation in their call graph (shortened to "using a gradient"
408     // in the rest of the file).
409     //
410     // This computation is logically split in three steps:
411     //  1 - For each function compute if it uses a gradient in its body, ignoring
412     // calls to other user-defined functions.
413     //  2 - For each function determine if it uses a gradient in its call graph,
414     // using the result of step 1 and the CallDAG to know its callees.
415     //  3 - For each control flow statement of each function, check if it uses a
416     // gradient in the function's body, or if it calls a user-defined function that
417     // uses a gradient.
418     //
419     // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
420     // for leaves first, then going down the tree. This is correct because 1 doesn't
421     // depend on other functions, and 2 and 3 depend only on callees.
422     for (size_t i = 0; i < callDag.size(); i++)
423     {
424         PullGradient pull(&metadataList, i, callDag);
425         pull.traverse(callDag.getRecordFromIndex(i).node);
426     }
427 
428     // Compute which loops are discontinuous and which function are called in
429     // these loops. The same way computing gradient usage is a "pull" process,
430     // computing "bing used in a discont. loop" is a push process. However we also
431     // need to know what ifs have a discontinuous loop inside so we do the same type
432     // of callgraph analysis as for the gradient.
433 
434     // First compute which loops are discontinuous (no specific order) and pull
435     // the ifs and functions using a gradient loop.
436     for (size_t i = 0; i < callDag.size(); i++)
437     {
438         PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
439         pull.traverse(callDag.getRecordFromIndex(i).node);
440     }
441 
442     // Then push the information to callees, either from the a local discontinuous
443     // loop or from the caller being called in a discontinuous loop already
444     for (size_t i = callDag.size(); i-- > 0;)
445     {
446         PushDiscontinuousLoops push(&metadataList, i, callDag);
447         push.traverse(callDag.getRecordFromIndex(i).node);
448     }
449 
450     // We create "Lod0" version of functions with the gradient operations replaced
451     // by non-gradient operations so that the D3D compiler is happier with discont
452     // loops.
453     for (auto &metadata : metadataList)
454     {
455         metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
456     }
457 
458     return metadataList;
459 }
460 
461 }  // namespace sh
462