1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // The ArrayReturnValueToOutParameter function changes return values of an array type to out
7 // parameters in function definitions, prototypes, and call sites.
8
9 #include "compiler/translator/ArrayReturnValueToOutParameter.h"
10
11 #include <map>
12
13 #include "compiler/translator/IntermNode_util.h"
14 #include "compiler/translator/IntermTraverse.h"
15 #include "compiler/translator/StaticType.h"
16 #include "compiler/translator/SymbolTable.h"
17
18 namespace sh
19 {
20
21 namespace
22 {
23
24 constexpr const ImmutableString kReturnValueVariableName("angle_return");
25
CopyAggregateChildren(TIntermAggregateBase * from,TIntermAggregateBase * to)26 void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to)
27 {
28 const TIntermSequence *fromSequence = from->getSequence();
29 for (size_t ii = 0; ii < fromSequence->size(); ++ii)
30 {
31 to->getSequence()->push_back(fromSequence->at(ii));
32 }
33 }
34
35 class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
36 {
37 public:
38 static void apply(TIntermNode *root, TSymbolTable *symbolTable);
39
40 private:
41 ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);
42
43 bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
44 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
45 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
46 bool visitBranch(Visit visit, TIntermBranch *node) override;
47 bool visitBinary(Visit visit, TIntermBinary *node) override;
48
49 TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
50 TIntermTyped *returnValueTarget);
51
52 // Set when traversal is inside a function with array return value.
53 TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
54
55 struct ChangedFunction
56 {
57 const TVariable *returnValueVariable;
58 const TFunction *func;
59 };
60
61 // Map from function symbol ids to the changed function.
62 std::map<int, ChangedFunction> mChangedFunctions;
63 };
64
createReplacementCall(TIntermAggregate * originalCall,TIntermTyped * returnValueTarget)65 TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
66 TIntermAggregate *originalCall,
67 TIntermTyped *returnValueTarget)
68 {
69 TIntermSequence *replacementArguments = new TIntermSequence();
70 TIntermSequence *originalArguments = originalCall->getSequence();
71 for (auto &arg : *originalArguments)
72 {
73 replacementArguments->push_back(arg);
74 }
75 replacementArguments->push_back(returnValueTarget);
76 ASSERT(originalCall->getFunction());
77 const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();
78 TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
79 *mChangedFunctions[originalId.get()].func, replacementArguments);
80 replacementCall->setLine(originalCall->getLine());
81 return replacementCall;
82 }
83
apply(TIntermNode * root,TSymbolTable * symbolTable)84 void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable)
85 {
86 ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable);
87 root->traverse(&arrayReturnValueToOutParam);
88 arrayReturnValueToOutParam.updateTree();
89 }
90
ArrayReturnValueToOutParameterTraverser(TSymbolTable * symbolTable)91 ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
92 TSymbolTable *symbolTable)
93 : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr)
94 {
95 }
96
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)97 bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
98 Visit visit,
99 TIntermFunctionDefinition *node)
100 {
101 if (node->getFunctionPrototype()->isArray() && visit == PreVisit)
102 {
103 // Replacing the function header is done on visitFunctionPrototype().
104 mFunctionWithArrayReturnValue = node;
105 }
106 if (visit == PostVisit)
107 {
108 mFunctionWithArrayReturnValue = nullptr;
109 }
110 return true;
111 }
112
visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)113 bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit,
114 TIntermFunctionPrototype *node)
115 {
116 if (visit == PreVisit && node->isArray())
117 {
118 // Replace the whole prototype node with another node that has the out parameter
119 // added. Also set the function to return void.
120 const TSymbolUniqueId &functionId = node->getFunction()->uniqueId();
121 if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end())
122 {
123 TType *returnValueVariableType = new TType(node->getType());
124 returnValueVariableType->setQualifier(EvqOut);
125 ChangedFunction changedFunction;
126 changedFunction.returnValueVariable =
127 new TVariable(mSymbolTable, kReturnValueVariableName, returnValueVariableType,
128 SymbolType::AngleInternal);
129 TFunction *func = new TFunction(mSymbolTable, node->getFunction()->name(),
130 node->getFunction()->symbolType(),
131 StaticType::GetBasic<EbtVoid>(), false);
132 for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i)
133 {
134 func->addParameter(node->getFunction()->getParam(i));
135 }
136 func->addParameter(TConstParameter(
137 kReturnValueVariableName, static_cast<const TType *>(returnValueVariableType)));
138 changedFunction.func = func;
139 mChangedFunctions[functionId.get()] = changedFunction;
140 }
141 TIntermFunctionPrototype *replacement =
142 new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
143 CopyAggregateChildren(node, replacement);
144 replacement->getSequence()->push_back(
145 new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable));
146 replacement->setLine(node->getLine());
147
148 queueReplacement(replacement, OriginalNode::IS_DROPPED);
149 }
150 return false;
151 }
152
visitAggregate(Visit visit,TIntermAggregate * node)153 bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
154 {
155 ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);
156 if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST)
157 {
158 // Handle call sites where the returned array is not assigned.
159 // Examples where f() is a function returning an array:
160 // 1. f();
161 // 2. another_array == f();
162 // 3. another_function(f());
163 // 4. return f();
164 // Cases 2 to 4 are already converted to simpler cases by
165 // SeparateExpressionsReturningArrays, so we only need to worry about the case where a
166 // function call returning an array forms an expression by itself.
167 TIntermBlock *parentBlock = getParentNode()->getAsBlock();
168 if (parentBlock)
169 {
170 // replace
171 // f();
172 // with
173 // type s0[size]; f(s0);
174 TIntermSequence replacements;
175
176 // type s0[size];
177 TIntermDeclaration *returnValueDeclaration = nullptr;
178 TVariable *returnValue = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
179 EvqTemporary, &returnValueDeclaration);
180 replacements.push_back(returnValueDeclaration);
181
182 // f(s0);
183 TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
184 replacements.push_back(createReplacementCall(node, returnValueSymbol));
185 mMultiReplacements.push_back(
186 NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
187 }
188 return false;
189 }
190 return true;
191 }
192
visitBranch(Visit visit,TIntermBranch * node)193 bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
194 {
195 if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn)
196 {
197 // Instead of returning a value, assign to the out parameter and then return.
198 TIntermSequence replacements;
199
200 TIntermTyped *expression = node->getExpression();
201 ASSERT(expression != nullptr);
202 const TSymbolUniqueId &functionId =
203 mFunctionWithArrayReturnValue->getFunction()->uniqueId();
204 ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());
205 TIntermSymbol *returnValueSymbol =
206 new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
207 TIntermBinary *replacementAssignment =
208 new TIntermBinary(EOpAssign, returnValueSymbol, expression);
209 replacementAssignment->setLine(expression->getLine());
210 replacements.push_back(replacementAssignment);
211
212 TIntermBranch *replacementBranch = new TIntermBranch(EOpReturn, nullptr);
213 replacementBranch->setLine(node->getLine());
214 replacements.push_back(replacementBranch);
215
216 mMultiReplacements.push_back(
217 NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, replacements));
218 }
219 return false;
220 }
221
visitBinary(Visit visit,TIntermBinary * node)222 bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
223 {
224 if (node->getOp() == EOpAssign && node->getLeft()->isArray())
225 {
226 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
227 ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction);
228 if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST)
229 {
230 TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft());
231 queueReplacement(replacementCall, OriginalNode::IS_DROPPED);
232 }
233 }
234 return false;
235 }
236
237 } // namespace
238
ArrayReturnValueToOutParameter(TIntermNode * root,TSymbolTable * symbolTable)239 void ArrayReturnValueToOutParameter(TIntermNode *root, TSymbolTable *symbolTable)
240 {
241 ArrayReturnValueToOutParameterTraverser::apply(root, symbolTable);
242 }
243
244 } // namespace sh
245