1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // The ArrayReturnValueToOutParameter function changes return values of an array type to out
7 // parameters in function definitions, prototypes, and call sites.
8 
9 #include "compiler/translator/ArrayReturnValueToOutParameter.h"
10 
11 #include <map>
12 
13 #include "compiler/translator/IntermNode_util.h"
14 #include "compiler/translator/IntermTraverse.h"
15 #include "compiler/translator/StaticType.h"
16 #include "compiler/translator/SymbolTable.h"
17 
18 namespace sh
19 {
20 
21 namespace
22 {
23 
24 constexpr const ImmutableString kReturnValueVariableName("angle_return");
25 
CopyAggregateChildren(TIntermAggregateBase * from,TIntermAggregateBase * to)26 void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to)
27 {
28     const TIntermSequence *fromSequence = from->getSequence();
29     for (size_t ii = 0; ii < fromSequence->size(); ++ii)
30     {
31         to->getSequence()->push_back(fromSequence->at(ii));
32     }
33 }
34 
35 class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
36 {
37   public:
38     static void apply(TIntermNode *root, TSymbolTable *symbolTable);
39 
40   private:
41     ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);
42 
43     bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
44     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
45     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
46     bool visitBranch(Visit visit, TIntermBranch *node) override;
47     bool visitBinary(Visit visit, TIntermBinary *node) override;
48 
49     TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
50                                             TIntermTyped *returnValueTarget);
51 
52     // Set when traversal is inside a function with array return value.
53     TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
54 
55     struct ChangedFunction
56     {
57         const TVariable *returnValueVariable;
58         const TFunction *func;
59     };
60 
61     // Map from function symbol ids to the changed function.
62     std::map<int, ChangedFunction> mChangedFunctions;
63 };
64 
createReplacementCall(TIntermAggregate * originalCall,TIntermTyped * returnValueTarget)65 TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
66     TIntermAggregate *originalCall,
67     TIntermTyped *returnValueTarget)
68 {
69     TIntermSequence *replacementArguments = new TIntermSequence();
70     TIntermSequence *originalArguments    = originalCall->getSequence();
71     for (auto &arg : *originalArguments)
72     {
73         replacementArguments->push_back(arg);
74     }
75     replacementArguments->push_back(returnValueTarget);
76     ASSERT(originalCall->getFunction());
77     const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();
78     TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
79         *mChangedFunctions[originalId.get()].func, replacementArguments);
80     replacementCall->setLine(originalCall->getLine());
81     return replacementCall;
82 }
83 
apply(TIntermNode * root,TSymbolTable * symbolTable)84 void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable)
85 {
86     ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable);
87     root->traverse(&arrayReturnValueToOutParam);
88     arrayReturnValueToOutParam.updateTree();
89 }
90 
ArrayReturnValueToOutParameterTraverser(TSymbolTable * symbolTable)91 ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
92     TSymbolTable *symbolTable)
93     : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr)
94 {
95 }
96 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)97 bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
98     Visit visit,
99     TIntermFunctionDefinition *node)
100 {
101     if (node->getFunctionPrototype()->isArray() && visit == PreVisit)
102     {
103         // Replacing the function header is done on visitFunctionPrototype().
104         mFunctionWithArrayReturnValue = node;
105     }
106     if (visit == PostVisit)
107     {
108         mFunctionWithArrayReturnValue = nullptr;
109     }
110     return true;
111 }
112 
visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)113 bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit,
114                                                                      TIntermFunctionPrototype *node)
115 {
116     if (visit == PreVisit && node->isArray())
117     {
118         // Replace the whole prototype node with another node that has the out parameter
119         // added. Also set the function to return void.
120         const TSymbolUniqueId &functionId = node->getFunction()->uniqueId();
121         if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end())
122         {
123             TType *returnValueVariableType = new TType(node->getType());
124             returnValueVariableType->setQualifier(EvqOut);
125             ChangedFunction changedFunction;
126             changedFunction.returnValueVariable =
127                 new TVariable(mSymbolTable, kReturnValueVariableName, returnValueVariableType,
128                               SymbolType::AngleInternal);
129             TFunction *func = new TFunction(mSymbolTable, node->getFunction()->name(),
130                                             node->getFunction()->symbolType(),
131                                             StaticType::GetBasic<EbtVoid>(), false);
132             for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i)
133             {
134                 func->addParameter(node->getFunction()->getParam(i));
135             }
136             func->addParameter(TConstParameter(
137                 kReturnValueVariableName, static_cast<const TType *>(returnValueVariableType)));
138             changedFunction.func                = func;
139             mChangedFunctions[functionId.get()] = changedFunction;
140         }
141         TIntermFunctionPrototype *replacement =
142             new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
143         CopyAggregateChildren(node, replacement);
144         replacement->getSequence()->push_back(
145             new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable));
146         replacement->setLine(node->getLine());
147 
148         queueReplacement(replacement, OriginalNode::IS_DROPPED);
149     }
150     return false;
151 }
152 
visitAggregate(Visit visit,TIntermAggregate * node)153 bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
154 {
155     ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);
156     if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST)
157     {
158         // Handle call sites where the returned array is not assigned.
159         // Examples where f() is a function returning an array:
160         // 1. f();
161         // 2. another_array == f();
162         // 3. another_function(f());
163         // 4. return f();
164         // Cases 2 to 4 are already converted to simpler cases by
165         // SeparateExpressionsReturningArrays, so we only need to worry about the case where a
166         // function call returning an array forms an expression by itself.
167         TIntermBlock *parentBlock = getParentNode()->getAsBlock();
168         if (parentBlock)
169         {
170             // replace
171             //   f();
172             // with
173             //   type s0[size]; f(s0);
174             TIntermSequence replacements;
175 
176             // type s0[size];
177             TIntermDeclaration *returnValueDeclaration = nullptr;
178             TVariable *returnValue = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
179                                                          EvqTemporary, &returnValueDeclaration);
180             replacements.push_back(returnValueDeclaration);
181 
182             // f(s0);
183             TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
184             replacements.push_back(createReplacementCall(node, returnValueSymbol));
185             mMultiReplacements.push_back(
186                 NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
187         }
188         return false;
189     }
190     return true;
191 }
192 
visitBranch(Visit visit,TIntermBranch * node)193 bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
194 {
195     if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn)
196     {
197         // Instead of returning a value, assign to the out parameter and then return.
198         TIntermSequence replacements;
199 
200         TIntermTyped *expression = node->getExpression();
201         ASSERT(expression != nullptr);
202         const TSymbolUniqueId &functionId =
203             mFunctionWithArrayReturnValue->getFunction()->uniqueId();
204         ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());
205         TIntermSymbol *returnValueSymbol =
206             new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
207         TIntermBinary *replacementAssignment =
208             new TIntermBinary(EOpAssign, returnValueSymbol, expression);
209         replacementAssignment->setLine(expression->getLine());
210         replacements.push_back(replacementAssignment);
211 
212         TIntermBranch *replacementBranch = new TIntermBranch(EOpReturn, nullptr);
213         replacementBranch->setLine(node->getLine());
214         replacements.push_back(replacementBranch);
215 
216         mMultiReplacements.push_back(
217             NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, replacements));
218     }
219     return false;
220 }
221 
visitBinary(Visit visit,TIntermBinary * node)222 bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
223 {
224     if (node->getOp() == EOpAssign && node->getLeft()->isArray())
225     {
226         TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
227         ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction);
228         if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST)
229         {
230             TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft());
231             queueReplacement(replacementCall, OriginalNode::IS_DROPPED);
232         }
233     }
234     return false;
235 }
236 
237 }  // namespace
238 
ArrayReturnValueToOutParameter(TIntermNode * root,TSymbolTable * symbolTable)239 void ArrayReturnValueToOutParameter(TIntermNode *root, TSymbolTable *symbolTable)
240 {
241     ArrayReturnValueToOutParameterTraverser::apply(root, symbolTable);
242 }
243 
244 }  // namespace sh
245