1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 // Analysis of the AST needed for HLSL generation
8 
9 #include "compiler/translator/ASTMetadataHLSL.h"
10 
11 #include "compiler/translator/CallDAG.h"
12 #include "compiler/translator/IntermTraverse.h"
13 #include "compiler/translator/SymbolTable.h"
14 
15 namespace sh
16 {
17 
18 namespace
19 {
20 
21 // Class used to traverse the AST of a function definition, checking if the
22 // function uses a gradient, and writing the set of control flow using gradients.
23 // It assumes that the analysis has already been made for the function's
24 // callees.
25 class PullGradient : public TIntermTraverser
26 {
27   public:
PullGradient(MetadataList * metadataList,size_t index,const CallDAG & dag)28     PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
29         : TIntermTraverser(true, false, true),
30           mMetadataList(metadataList),
31           mMetadata(&(*metadataList)[index]),
32           mIndex(index),
33           mDag(dag)
34     {
35         ASSERT(index < metadataList->size());
36 
37         // ESSL 100 builtin gradient functions
38         mGradientBuiltinFunctions.insert(ImmutableString("texture2D"));
39         mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj"));
40         mGradientBuiltinFunctions.insert(ImmutableString("textureCube"));
41 
42         // ESSL 300 builtin gradient functions
43         mGradientBuiltinFunctions.insert(ImmutableString("texture"));
44         mGradientBuiltinFunctions.insert(ImmutableString("textureProj"));
45         mGradientBuiltinFunctions.insert(ImmutableString("textureOffset"));
46         mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset"));
47 
48         // ESSL 310 doesn't add builtin gradient functions
49     }
50 
traverse(TIntermFunctionDefinition * node)51     void traverse(TIntermFunctionDefinition *node)
52     {
53         node->traverse(this);
54         ASSERT(mParents.empty());
55     }
56 
57     // Called when a gradient operation or a call to a function using a gradient is found.
onGradient()58     void onGradient()
59     {
60         mMetadata->mUsesGradient = true;
61         // Mark the latest control flow as using a gradient.
62         if (!mParents.empty())
63         {
64             mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
65         }
66     }
67 
visitControlFlow(Visit visit,TIntermNode * node)68     void visitControlFlow(Visit visit, TIntermNode *node)
69     {
70         if (visit == PreVisit)
71         {
72             mParents.push_back(node);
73         }
74         else if (visit == PostVisit)
75         {
76             ASSERT(mParents.back() == node);
77             mParents.pop_back();
78             // A control flow's using a gradient means its parents are too.
79             if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
80             {
81                 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
82             }
83         }
84     }
85 
visitLoop(Visit visit,TIntermLoop * loop)86     bool visitLoop(Visit visit, TIntermLoop *loop) override
87     {
88         visitControlFlow(visit, loop);
89         return true;
90     }
91 
visitIfElse(Visit visit,TIntermIfElse * ifElse)92     bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
93     {
94         visitControlFlow(visit, ifElse);
95         return true;
96     }
97 
visitUnary(Visit visit,TIntermUnary * node)98     bool visitUnary(Visit visit, TIntermUnary *node) override
99     {
100         if (visit == PreVisit)
101         {
102             switch (node->getOp())
103             {
104                 case EOpDFdx:
105                 case EOpDFdy:
106                 case EOpFwidth:
107                     onGradient();
108                     break;
109                 default:
110                     break;
111             }
112         }
113 
114         return true;
115     }
116 
visitAggregate(Visit visit,TIntermAggregate * node)117     bool visitAggregate(Visit visit, TIntermAggregate *node) override
118     {
119         if (visit == PreVisit)
120         {
121             if (node->getOp() == EOpCallFunctionInAST)
122             {
123                 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
124                 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
125 
126                 if ((*mMetadataList)[calleeIndex].mUsesGradient)
127                 {
128                     onGradient();
129                 }
130             }
131             else if (node->getOp() == EOpCallBuiltInFunction)
132             {
133                 if (mGradientBuiltinFunctions.find(node->getFunction()->name()) !=
134                     mGradientBuiltinFunctions.end())
135                 {
136                     onGradient();
137                 }
138             }
139         }
140 
141         return true;
142     }
143 
144   private:
145     MetadataList *mMetadataList;
146     ASTMetadataHLSL *mMetadata;
147     size_t mIndex;
148     const CallDAG &mDag;
149 
150     // Contains a stack of the control flow nodes that are parents of the node being
151     // currently visited. It is used to mark control flows using a gradient.
152     std::vector<TIntermNode *> mParents;
153 
154     // A list of builtin functions that use gradients
155     std::set<ImmutableString> mGradientBuiltinFunctions;
156 };
157 
158 // Traverses the AST of a function definition to compute the the discontinuous loops
159 // and the if statements containing gradient loops. It assumes that the gradient loops
160 // (loops that contain a gradient) have already been computed and that it has already
161 // traversed the current function's callees.
162 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
163 {
164   public:
PullComputeDiscontinuousAndGradientLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)165     PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
166                                              size_t index,
167                                              const CallDAG &dag)
168         : TIntermTraverser(true, false, true),
169           mMetadataList(metadataList),
170           mMetadata(&(*metadataList)[index]),
171           mIndex(index),
172           mDag(dag)
173     {
174     }
175 
traverse(TIntermFunctionDefinition * node)176     void traverse(TIntermFunctionDefinition *node)
177     {
178         node->traverse(this);
179         ASSERT(mLoopsAndSwitches.empty());
180         ASSERT(mIfs.empty());
181     }
182 
183     // Called when traversing a gradient loop or a call to a function with a
184     // gradient loop in its call graph.
onGradientLoop()185     void onGradientLoop()
186     {
187         mMetadata->mHasGradientLoopInCallGraph = true;
188         // Mark the latest if as using a discontinuous loop.
189         if (!mIfs.empty())
190         {
191             mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
192         }
193     }
194 
visitLoop(Visit visit,TIntermLoop * loop)195     bool visitLoop(Visit visit, TIntermLoop *loop) override
196     {
197         if (visit == PreVisit)
198         {
199             mLoopsAndSwitches.push_back(loop);
200 
201             if (mMetadata->hasGradientInCallGraph(loop))
202             {
203                 onGradientLoop();
204             }
205         }
206         else if (visit == PostVisit)
207         {
208             ASSERT(mLoopsAndSwitches.back() == loop);
209             mLoopsAndSwitches.pop_back();
210         }
211 
212         return true;
213     }
214 
visitIfElse(Visit visit,TIntermIfElse * node)215     bool visitIfElse(Visit visit, TIntermIfElse *node) override
216     {
217         if (visit == PreVisit)
218         {
219             mIfs.push_back(node);
220         }
221         else if (visit == PostVisit)
222         {
223             ASSERT(mIfs.back() == node);
224             mIfs.pop_back();
225             // An if using a discontinuous loop means its parents ifs are also discontinuous.
226             if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
227             {
228                 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
229             }
230         }
231 
232         return true;
233     }
234 
visitBranch(Visit visit,TIntermBranch * node)235     bool visitBranch(Visit visit, TIntermBranch *node) override
236     {
237         if (visit == PreVisit)
238         {
239             switch (node->getFlowOp())
240             {
241                 case EOpBreak:
242                 {
243                     ASSERT(!mLoopsAndSwitches.empty());
244                     TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
245                     if (loop != nullptr)
246                     {
247                         mMetadata->mDiscontinuousLoops.insert(loop);
248                     }
249                 }
250                 break;
251                 case EOpContinue:
252                 {
253                     ASSERT(!mLoopsAndSwitches.empty());
254                     TIntermLoop *loop = nullptr;
255                     size_t i          = mLoopsAndSwitches.size();
256                     while (loop == nullptr && i > 0)
257                     {
258                         --i;
259                         loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
260                     }
261                     ASSERT(loop != nullptr);
262                     mMetadata->mDiscontinuousLoops.insert(loop);
263                 }
264                 break;
265                 case EOpKill:
266                 case EOpReturn:
267                     // A return or discard jumps out of all the enclosing loops
268                     if (!mLoopsAndSwitches.empty())
269                     {
270                         for (TIntermNode *intermNode : mLoopsAndSwitches)
271                         {
272                             TIntermLoop *loop = intermNode->getAsLoopNode();
273                             if (loop)
274                             {
275                                 mMetadata->mDiscontinuousLoops.insert(loop);
276                             }
277                         }
278                     }
279                     break;
280                 default:
281                     UNREACHABLE();
282             }
283         }
284 
285         return true;
286     }
287 
visitAggregate(Visit visit,TIntermAggregate * node)288     bool visitAggregate(Visit visit, TIntermAggregate *node) override
289     {
290         if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
291         {
292             size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
293             ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
294 
295             if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
296             {
297                 onGradientLoop();
298             }
299         }
300 
301         return true;
302     }
303 
visitSwitch(Visit visit,TIntermSwitch * node)304     bool visitSwitch(Visit visit, TIntermSwitch *node) override
305     {
306         if (visit == PreVisit)
307         {
308             mLoopsAndSwitches.push_back(node);
309         }
310         else if (visit == PostVisit)
311         {
312             ASSERT(mLoopsAndSwitches.back() == node);
313             mLoopsAndSwitches.pop_back();
314         }
315         return true;
316     }
317 
318   private:
319     MetadataList *mMetadataList;
320     ASTMetadataHLSL *mMetadata;
321     size_t mIndex;
322     const CallDAG &mDag;
323 
324     std::vector<TIntermNode *> mLoopsAndSwitches;
325     std::vector<TIntermIfElse *> mIfs;
326 };
327 
328 // Tags all the functions called in a discontinuous loop
329 class PushDiscontinuousLoops : public TIntermTraverser
330 {
331   public:
PushDiscontinuousLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)332     PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
333         : TIntermTraverser(true, true, true),
334           mMetadataList(metadataList),
335           mMetadata(&(*metadataList)[index]),
336           mIndex(index),
337           mDag(dag),
338           mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
339     {
340     }
341 
traverse(TIntermFunctionDefinition * node)342     void traverse(TIntermFunctionDefinition *node)
343     {
344         node->traverse(this);
345         ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
346     }
347 
visitLoop(Visit visit,TIntermLoop * loop)348     bool visitLoop(Visit visit, TIntermLoop *loop) override
349     {
350         bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
351 
352         if (visit == PreVisit && isDiscontinuous)
353         {
354             mNestedDiscont++;
355         }
356         else if (visit == PostVisit && isDiscontinuous)
357         {
358             mNestedDiscont--;
359         }
360 
361         return true;
362     }
363 
visitAggregate(Visit visit,TIntermAggregate * node)364     bool visitAggregate(Visit visit, TIntermAggregate *node) override
365     {
366         switch (node->getOp())
367         {
368             case EOpCallFunctionInAST:
369                 if (visit == PreVisit && mNestedDiscont > 0)
370                 {
371                     size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
372                     ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
373 
374                     (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
375                 }
376                 break;
377             default:
378                 break;
379         }
380         return true;
381     }
382 
383   private:
384     MetadataList *mMetadataList;
385     ASTMetadataHLSL *mMetadata;
386     size_t mIndex;
387     const CallDAG &mDag;
388 
389     int mNestedDiscont;
390 };
391 }
392 
hasGradientInCallGraph(TIntermLoop * node)393 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
394 {
395     return mControlFlowsContainingGradient.count(node) > 0;
396 }
397 
hasGradientLoop(TIntermIfElse * node)398 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
399 {
400     return mIfsContainingGradientLoop.count(node) > 0;
401 }
402 
CreateASTMetadataHLSL(TIntermNode * root,const CallDAG & callDag)403 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
404 {
405     MetadataList metadataList(callDag.size());
406 
407     // Compute all the information related to when gradient operations are used.
408     // We want to know for each function and control flow operation if they have
409     // a gradient operation in their call graph (shortened to "using a gradient"
410     // in the rest of the file).
411     //
412     // This computation is logically split in three steps:
413     //  1 - For each function compute if it uses a gradient in its body, ignoring
414     // calls to other user-defined functions.
415     //  2 - For each function determine if it uses a gradient in its call graph,
416     // using the result of step 1 and the CallDAG to know its callees.
417     //  3 - For each control flow statement of each function, check if it uses a
418     // gradient in the function's body, or if it calls a user-defined function that
419     // uses a gradient.
420     //
421     // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
422     // for leaves first, then going down the tree. This is correct because 1 doesn't
423     // depend on other functions, and 2 and 3 depend only on callees.
424     for (size_t i = 0; i < callDag.size(); i++)
425     {
426         PullGradient pull(&metadataList, i, callDag);
427         pull.traverse(callDag.getRecordFromIndex(i).node);
428     }
429 
430     // Compute which loops are discontinuous and which function are called in
431     // these loops. The same way computing gradient usage is a "pull" process,
432     // computing "bing used in a discont. loop" is a push process. However we also
433     // need to know what ifs have a discontinuous loop inside so we do the same type
434     // of callgraph analysis as for the gradient.
435 
436     // First compute which loops are discontinuous (no specific order) and pull
437     // the ifs and functions using a gradient loop.
438     for (size_t i = 0; i < callDag.size(); i++)
439     {
440         PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
441         pull.traverse(callDag.getRecordFromIndex(i).node);
442     }
443 
444     // Then push the information to callees, either from the a local discontinuous
445     // loop or from the caller being called in a discontinuous loop already
446     for (size_t i = callDag.size(); i-- > 0;)
447     {
448         PushDiscontinuousLoops push(&metadataList, i, callDag);
449         push.traverse(callDag.getRecordFromIndex(i).node);
450     }
451 
452     // We create "Lod0" version of functions with the gradient operations replaced
453     // by non-gradient operations so that the D3D compiler is happier with discont
454     // loops.
455     for (auto &metadata : metadataList)
456     {
457         metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
458     }
459 
460     return metadataList;
461 }
462 
463 }  // namespace sh
464