1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // AtomicCounterFunctionHLSL: Class for writing implementation of atomic counter functions into HLSL
7 // output.
8 //
9 
10 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
11 
12 #include "compiler/translator/Common.h"
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/InfoSink.h"
15 #include "compiler/translator/IntermNode.h"
16 
17 namespace sh
18 {
19 
20 namespace
21 {
22 constexpr ImmutableString kAtomicCounter("atomicCounter");
23 constexpr ImmutableString kAtomicCounterIncrement("atomicCounterIncrement");
24 constexpr ImmutableString kAtomicCounterDecrement("atomicCounterDecrement");
25 constexpr ImmutableString kAtomicCounterBaseName("_acbase_");
26 }  // namespace
27 
AtomicCounterFunctionHLSL(bool forceResolution)28 AtomicCounterFunctionHLSL::AtomicCounterFunctionHLSL(bool forceResolution)
29     : mForceResolution(forceResolution)
30 {}
31 
// Maps a GLSL atomic counter built-in (atomicCounter, atomicCounterIncrement or
// atomicCounterDecrement) to the name of the HLSL helper that implements it. It also records
// that the helper must be written out by atomicCounterFunctionHeader().
//
// name: the GLSL built-in function name being called.
// Returns "_acbase_load", "_acbase_increment" or "_acbase_decrement".
//
// Any other name is a programming error (UNREACHABLE). In that case the bare prefix is returned
// and nothing is registered. Registering it would make atomicCounterFunctionHeader() emit a
// helper that returns an uninitialized value.
ImmutableString AtomicCounterFunctionHLSL::useAtomicCounterFunction(const ImmutableString &name)
{
    // The longest name built here is "_acbase_increment" or "_acbase_decrement".
    ImmutableStringBuilder hlslFunctionNameSB(kAtomicCounterBaseName.length() +
                                              strlen("increment"));
    hlslFunctionNameSB << kAtomicCounterBaseName;

    AtomicCounterFunction atomicMethod = AtomicCounterFunction::INVALID;
    if (kAtomicCounter == name)
    {
        atomicMethod = AtomicCounterFunction::LOAD;
        hlslFunctionNameSB << "load";
    }
    else if (kAtomicCounterIncrement == name)
    {
        atomicMethod = AtomicCounterFunction::INCREMENT;
        hlslFunctionNameSB << "increment";
    }
    else if (kAtomicCounterDecrement == name)
    {
        atomicMethod = AtomicCounterFunction::DECREMENT;
        hlslFunctionNameSB << "decrement";
    }
    else
    {
        UNREACHABLE();
    }

    ImmutableString hlslFunctionName(hlslFunctionNameSB);

    // Only recognized operations get a helper definition emitted later.
    if (atomicMethod != AtomicCounterFunction::INVALID)
    {
        mAtomicCounterFunctions[hlslFunctionName] = atomicMethod;
    }

    return hlslFunctionName;
}
66 
atomicCounterFunctionHeader(TInfoSinkBase & out)67 void AtomicCounterFunctionHLSL::atomicCounterFunctionHeader(TInfoSinkBase &out)
68 {
69     for (auto &atomicFunction : mAtomicCounterFunctions)
70     {
71         out << "uint " << atomicFunction.first
72             << "(in RWByteAddressBuffer counter, int address)\n"
73                "{\n"
74                "    uint ret;\n";
75 
76         switch (atomicFunction.second)
77         {
78             case AtomicCounterFunction::INCREMENT:
79                 out << "    counter.InterlockedAdd(address, 1u, ret);\n";
80                 break;
81             case AtomicCounterFunction::DECREMENT:
82                 out << "    counter.InterlockedAdd(address, 0u - 1u, ret);\n"
83                        "    ret -= 1u;\n";  // atomicCounterDecrement is a post-decrement op
84                 break;
85             case AtomicCounterFunction::LOAD:
86                 out << "    ret = counter.Load(address);\n";
87                 break;
88             default:
89                 UNREACHABLE();
90                 break;
91         }
92 
93         if (mForceResolution && atomicFunction.second != AtomicCounterFunction::LOAD)
94         {
95             out << "    if (ret == 0) {\n"
96                    "        ret = 0 - ret;\n"
97                    "    }\n";
98         }
99 
100         out << "    return ret;\n"
101                "}\n\n";
102     }
103 }
104 
getAtomicCounterNameForBinding(int binding)105 ImmutableString getAtomicCounterNameForBinding(int binding)
106 {
107     std::stringstream counterName = sh::InitializeStream<std::stringstream>();
108     counterName << kAtomicCounterBaseName << binding;
109     return ImmutableString(counterName.str());
110 }
111 
112 }  // namespace sh
113