1 //
2 // Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
7 // return statement, this is done by replacing the main() function with another function that calls
8 // the old main, like this:
9 //
10 // void main() { body }
11 // =>
12 // void main0() { body }
13 // void main()
14 // {
15 //     main0();
16 //     codeToRun
17 // }
18 //
19 // This way the code will get run even if the return statement inside main is executed.
20 //
21 
#include "compiler/translator/RunAtTheEndOfShader.h"

#include <sstream>
#include <string>

#include "compiler/translator/FindMain.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h"
#include "compiler/translator/SymbolTable.h"
29 
30 namespace sh
31 {
32 
33 namespace
34 {
35 
36 class ContainsReturnTraverser : public TIntermTraverser
37 {
38   public:
ContainsReturnTraverser()39     ContainsReturnTraverser() : TIntermTraverser(true, false, false), mContainsReturn(false) {}
40 
visitBranch(Visit visit,TIntermBranch * node)41     bool visitBranch(Visit visit, TIntermBranch *node) override
42     {
43         if (node->getFlowOp() == EOpReturn)
44         {
45             mContainsReturn = true;
46         }
47         return false;
48     }
49 
containsReturn()50     bool containsReturn() { return mContainsReturn; }
51 
52   private:
53     bool mContainsReturn;
54 };
55 
ContainsReturn(TIntermNode * node)56 bool ContainsReturn(TIntermNode *node)
57 {
58     ContainsReturnTraverser traverser;
59     node->traverse(&traverser);
60     return traverser.containsReturn();
61 }
62 
WrapMainAndAppend(TIntermBlock * root,TIntermFunctionDefinition * main,TIntermNode * codeToRun,TSymbolTable * symbolTable)63 void WrapMainAndAppend(TIntermBlock *root,
64                        TIntermFunctionDefinition *main,
65                        TIntermNode *codeToRun,
66                        TSymbolTable *symbolTable)
67 {
68     // Replace main() with main0() with the same body.
69     TSymbolUniqueId oldMainId(symbolTable);
70     std::stringstream oldMainName;
71     oldMainName << "main" << oldMainId.get();
72     TIntermFunctionDefinition *oldMain = CreateInternalFunctionDefinitionNode(
73         TType(EbtVoid), oldMainName.str().c_str(), main->getBody(), oldMainId);
74 
75     bool replaced = root->replaceChildNode(main, oldMain);
76     ASSERT(replaced);
77 
78     // void main()
79     TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(
80         TType(EbtVoid), main->getFunctionPrototype()->getFunctionSymbolInfo()->getId());
81     newMainProto->getFunctionSymbolInfo()->setName("main");
82 
83     // {
84     //     main0();
85     //     codeToRun
86     // }
87     TIntermBlock *newMainBody     = new TIntermBlock();
88     TIntermAggregate *oldMainCall = CreateInternalFunctionCallNode(
89         TType(EbtVoid), oldMainName.str().c_str(), oldMainId, new TIntermSequence());
90     newMainBody->appendStatement(oldMainCall);
91     newMainBody->appendStatement(codeToRun);
92 
93     // Add the new main() to the root node.
94     TIntermFunctionDefinition *newMain = new TIntermFunctionDefinition(newMainProto, newMainBody);
95     root->appendStatement(newMain);
96 }
97 
98 }  // anonymous namespace
99 
RunAtTheEndOfShader(TIntermBlock * root,TIntermNode * codeToRun,TSymbolTable * symbolTable)100 void RunAtTheEndOfShader(TIntermBlock *root, TIntermNode *codeToRun, TSymbolTable *symbolTable)
101 {
102     TIntermFunctionDefinition *main = FindMain(root);
103     if (!ContainsReturn(main))
104     {
105         main->getBody()->appendStatement(codeToRun);
106         return;
107     }
108 
109     WrapMainAndAppend(root, main, codeToRun, symbolTable);
110 }
111 
112 }  // namespace sh
113