1 //
2 // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of the integer pow expressions HLSL bug workaround.
7 // See header for more info.
8
9 #include "compiler/translator/ExpandIntegerPowExpressions.h"
10
11 #include <cmath>
12 #include <cstdlib>
13
14 #include "compiler/translator/IntermTraverse.h"
15
16 namespace sh
17 {
18
19 namespace
20 {
21
22 class Traverser : public TIntermTraverser
23 {
24 public:
25 static void Apply(TIntermNode *root, TSymbolTable *symbolTable);
26
27 private:
28 Traverser(TSymbolTable *symbolTable);
29 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
30 void nextIteration();
31
32 bool mFound = false;
33 };
34
35 // static
Apply(TIntermNode * root,TSymbolTable * symbolTable)36 void Traverser::Apply(TIntermNode *root, TSymbolTable *symbolTable)
37 {
38 Traverser traverser(symbolTable);
39 do
40 {
41 traverser.nextIteration();
42 root->traverse(&traverser);
43 if (traverser.mFound)
44 {
45 traverser.updateTree();
46 }
47 } while (traverser.mFound);
48 }
49
Traverser(TSymbolTable * symbolTable)50 Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
51 {
52 }
53
nextIteration()54 void Traverser::nextIteration()
55 {
56 mFound = false;
57 nextTemporaryId();
58 }
59
visitAggregate(Visit visit,TIntermAggregate * node)60 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
61 {
62 if (mFound)
63 {
64 return false;
65 }
66
67 // Test 0: skip non-pow operators.
68 if (node->getOp() != EOpPow)
69 {
70 return true;
71 }
72
73 const TIntermSequence *sequence = node->getSequence();
74 ASSERT(sequence->size() == 2u);
75 const TIntermConstantUnion *constantNode = sequence->at(1)->getAsConstantUnion();
76
77 // Test 1: check for a single constant.
78 if (!constantNode || constantNode->getNominalSize() != 1)
79 {
80 return true;
81 }
82
83 const TConstantUnion *constant = constantNode->getUnionArrayPointer();
84
85 TConstantUnion asFloat;
86 asFloat.cast(EbtFloat, *constant);
87
88 float value = asFloat.getFConst();
89
90 // Test 2: value is in the problematic range.
91 if (value < -5.0f || value > 9.0f)
92 {
93 return true;
94 }
95
96 // Test 3: value is integer or pretty close to an integer.
97 float absval = std::abs(value);
98 float frac = absval - std::round(absval);
99 if (frac > 0.0001f)
100 {
101 return true;
102 }
103
104 // Test 4: skip -1, 0, and 1
105 int exponent = static_cast<int>(value);
106 int n = std::abs(exponent);
107 if (n < 2)
108 {
109 return true;
110 }
111
112 // Potential problem case detected, apply workaround.
113 nextTemporaryId();
114
115 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
116 ASSERT(lhs);
117
118 TIntermDeclaration *init = createTempInitDeclaration(lhs);
119 TIntermTyped *current = createTempSymbol(lhs->getType());
120
121 insertStatementInParentBlock(init);
122
123 // Create a chain of n-1 multiples.
124 for (int i = 1; i < n; ++i)
125 {
126 TIntermBinary *mul = new TIntermBinary(EOpMul, current, createTempSymbol(lhs->getType()));
127 mul->setLine(node->getLine());
128 current = mul;
129 }
130
131 // For negative pow, compute the reciprocal of the positive pow.
132 if (exponent < 0)
133 {
134 TConstantUnion *oneVal = new TConstantUnion();
135 oneVal->setFConst(1.0f);
136 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
137 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
138 current = div;
139 }
140
141 queueReplacement(current, OriginalNode::IS_DROPPED);
142 mFound = true;
143 return false;
144 }
145
146 } // anonymous namespace
147
ExpandIntegerPowExpressions(TIntermNode * root,TSymbolTable * symbolTable)148 void ExpandIntegerPowExpressions(TIntermNode *root, TSymbolTable *symbolTable)
149 {
150 Traverser::Apply(root, symbolTable);
151 }
152
153 } // namespace sh
154