1 //
2 // Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
7 // return statement, this is done by replacing the main() function with another function that calls
8 // the old main, like this:
9 //
10 // void main() { body }
11 // =>
12 // void main0() { body }
13 // void main()
14 // {
15 // main0();
16 // codeToRun
17 // }
18 //
19 // This way the code will get run even if the return statement inside main is executed.
20 //
21
22 #include "compiler/translator/RunAtTheEndOfShader.h"
23
24 #include "compiler/translator/FindMain.h"
25 #include "compiler/translator/IntermNode.h"
26 #include "compiler/translator/IntermNode_util.h"
27 #include "compiler/translator/IntermTraverse.h"
28 #include "compiler/translator/SymbolTable.h"
29
30 namespace sh
31 {
32
33 namespace
34 {
35
36 class ContainsReturnTraverser : public TIntermTraverser
37 {
38 public:
ContainsReturnTraverser()39 ContainsReturnTraverser() : TIntermTraverser(true, false, false), mContainsReturn(false) {}
40
visitBranch(Visit visit,TIntermBranch * node)41 bool visitBranch(Visit visit, TIntermBranch *node) override
42 {
43 if (node->getFlowOp() == EOpReturn)
44 {
45 mContainsReturn = true;
46 }
47 return false;
48 }
49
containsReturn()50 bool containsReturn() { return mContainsReturn; }
51
52 private:
53 bool mContainsReturn;
54 };
55
ContainsReturn(TIntermNode * node)56 bool ContainsReturn(TIntermNode *node)
57 {
58 ContainsReturnTraverser traverser;
59 node->traverse(&traverser);
60 return traverser.containsReturn();
61 }
62
WrapMainAndAppend(TIntermBlock * root,TIntermFunctionDefinition * main,TIntermNode * codeToRun,TSymbolTable * symbolTable)63 void WrapMainAndAppend(TIntermBlock *root,
64 TIntermFunctionDefinition *main,
65 TIntermNode *codeToRun,
66 TSymbolTable *symbolTable)
67 {
68 // Replace main() with main0() with the same body.
69 TSymbolUniqueId oldMainId(symbolTable);
70 std::stringstream oldMainName;
71 oldMainName << "main" << oldMainId.get();
72 TIntermFunctionDefinition *oldMain = CreateInternalFunctionDefinitionNode(
73 TType(EbtVoid), oldMainName.str().c_str(), main->getBody(), oldMainId);
74
75 bool replaced = root->replaceChildNode(main, oldMain);
76 ASSERT(replaced);
77
78 // void main()
79 TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(
80 TType(EbtVoid), main->getFunctionPrototype()->getFunctionSymbolInfo()->getId());
81 newMainProto->getFunctionSymbolInfo()->setName("main");
82
83 // {
84 // main0();
85 // codeToRun
86 // }
87 TIntermBlock *newMainBody = new TIntermBlock();
88 TIntermAggregate *oldMainCall = CreateInternalFunctionCallNode(
89 TType(EbtVoid), oldMainName.str().c_str(), oldMainId, new TIntermSequence());
90 newMainBody->appendStatement(oldMainCall);
91 newMainBody->appendStatement(codeToRun);
92
93 // Add the new main() to the root node.
94 TIntermFunctionDefinition *newMain = new TIntermFunctionDefinition(newMainProto, newMainBody);
95 root->appendStatement(newMain);
96 }
97
98 } // anonymous namespace
99
RunAtTheEndOfShader(TIntermBlock * root,TIntermNode * codeToRun,TSymbolTable * symbolTable)100 void RunAtTheEndOfShader(TIntermBlock *root, TIntermNode *codeToRun, TSymbolTable *symbolTable)
101 {
102 TIntermFunctionDefinition *main = FindMain(root);
103 if (!ContainsReturn(main))
104 {
105 main->getBody()->appendStatement(codeToRun);
106 return;
107 }
108
109 WrapMainAndAppend(root, main, codeToRun, symbolTable);
110 }
111
112 } // namespace sh
113