1 //
2 // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of the integer pow expressions HLSL bug workaround.
7 // See header for more info.
8 
9 #include "compiler/translator/ExpandIntegerPowExpressions.h"
10 
11 #include <cmath>
12 #include <cstdlib>
13 
14 #include "compiler/translator/IntermTraverse.h"
15 
16 namespace sh
17 {
18 
19 namespace
20 {
21 
22 class Traverser : public TIntermTraverser
23 {
24   public:
25     static void Apply(TIntermNode *root, TSymbolTable *symbolTable);
26 
27   private:
28     Traverser(TSymbolTable *symbolTable);
29     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
30     void nextIteration();
31 
32     bool mFound = false;
33 };
34 
35 // static
Apply(TIntermNode * root,TSymbolTable * symbolTable)36 void Traverser::Apply(TIntermNode *root, TSymbolTable *symbolTable)
37 {
38     Traverser traverser(symbolTable);
39     do
40     {
41         traverser.nextIteration();
42         root->traverse(&traverser);
43         if (traverser.mFound)
44         {
45             traverser.updateTree();
46         }
47     } while (traverser.mFound);
48 }
49 
Traverser(TSymbolTable * symbolTable)50 Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
51 {
52 }
53 
nextIteration()54 void Traverser::nextIteration()
55 {
56     mFound = false;
57     nextTemporaryId();
58 }
59 
visitAggregate(Visit visit,TIntermAggregate * node)60 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
61 {
62     if (mFound)
63     {
64         return false;
65     }
66 
67     // Test 0: skip non-pow operators.
68     if (node->getOp() != EOpPow)
69     {
70         return true;
71     }
72 
73     const TIntermSequence *sequence = node->getSequence();
74     ASSERT(sequence->size() == 2u);
75     const TIntermConstantUnion *constantNode = sequence->at(1)->getAsConstantUnion();
76 
77     // Test 1: check for a single constant.
78     if (!constantNode || constantNode->getNominalSize() != 1)
79     {
80         return true;
81     }
82 
83     const TConstantUnion *constant = constantNode->getUnionArrayPointer();
84 
85     TConstantUnion asFloat;
86     asFloat.cast(EbtFloat, *constant);
87 
88     float value = asFloat.getFConst();
89 
90     // Test 2: value is in the problematic range.
91     if (value < -5.0f || value > 9.0f)
92     {
93         return true;
94     }
95 
96     // Test 3: value is integer or pretty close to an integer.
97     float absval = std::abs(value);
98     float frac   = absval - std::round(absval);
99     if (frac > 0.0001f)
100     {
101         return true;
102     }
103 
104     // Test 4: skip -1, 0, and 1
105     int exponent = static_cast<int>(value);
106     int n        = std::abs(exponent);
107     if (n < 2)
108     {
109         return true;
110     }
111 
112     // Potential problem case detected, apply workaround.
113     nextTemporaryId();
114 
115     TIntermTyped *lhs = sequence->at(0)->getAsTyped();
116     ASSERT(lhs);
117 
118     TIntermDeclaration *init = createTempInitDeclaration(lhs);
119     TIntermTyped *current    = createTempSymbol(lhs->getType());
120 
121     insertStatementInParentBlock(init);
122 
123     // Create a chain of n-1 multiples.
124     for (int i = 1; i < n; ++i)
125     {
126         TIntermBinary *mul = new TIntermBinary(EOpMul, current, createTempSymbol(lhs->getType()));
127         mul->setLine(node->getLine());
128         current = mul;
129     }
130 
131     // For negative pow, compute the reciprocal of the positive pow.
132     if (exponent < 0)
133     {
134         TConstantUnion *oneVal = new TConstantUnion();
135         oneVal->setFConst(1.0f);
136         TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
137         TIntermBinary *div            = new TIntermBinary(EOpDiv, oneNode, current);
138         current                       = div;
139     }
140 
141     queueReplacement(current, OriginalNode::IS_DROPPED);
142     mFound = true;
143     return false;
144 }
145 
146 }  // anonymous namespace
147 
ExpandIntegerPowExpressions(TIntermNode * root,TSymbolTable * symbolTable)148 void ExpandIntegerPowExpressions(TIntermNode *root, TSymbolTable *symbolTable)
149 {
150     Traverser::Apply(root, symbolTable);
151 }
152 
153 }  // namespace sh
154