1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <map>
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/cfg.h"
24 #include "source/opt/ir_context.h"
25 #include "source/opt/pass.h"
26 #include "source/opt/propagator.h"
27 
28 namespace spvtools {
29 namespace opt {
30 namespace {
31 
32 using ::testing::UnorderedElementsAre;
33 
34 class PropagatorTest : public testing::Test {
35  protected:
TearDown()36   virtual void TearDown() {
37     ctx_.reset(nullptr);
38     values_.clear();
39     values_vec_.clear();
40   }
41 
Assemble(const std::string & input)42   void Assemble(const std::string& input) {
43     ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
44     ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n"
45                              << input << "\n";
46   }
47 
Propagate(const SSAPropagator::VisitFunction & visit_fn)48   bool Propagate(const SSAPropagator::VisitFunction& visit_fn) {
49     SSAPropagator propagator(ctx_.get(), visit_fn);
50     bool retval = false;
51     for (auto& fn : *ctx_->module()) {
52       retval |= propagator.Run(&fn);
53     }
54     return retval;
55   }
56 
GetValues()57   const std::vector<uint32_t>& GetValues() {
58     values_vec_.clear();
59     for (const auto& it : values_) {
60       values_vec_.push_back(it.second);
61     }
62     return values_vec_;
63   }
64 
65   std::unique_ptr<IRContext> ctx_;
66   std::map<uint32_t, uint32_t> values_;
67   std::vector<uint32_t> values_vec_;
68 };
69 
TEST_F(PropagatorTest,LocalPropagate)70 TEST_F(PropagatorTest, LocalPropagate) {
71   const std::string spv_asm = R"(
72                OpCapability Shader
73           %1 = OpExtInstImport "GLSL.std.450"
74                OpMemoryModel Logical GLSL450
75                OpEntryPoint Fragment %main "main" %outparm
76                OpExecutionMode %main OriginUpperLeft
77                OpSource GLSL 450
78                OpName %main "main"
79                OpName %x "x"
80                OpName %y "y"
81                OpName %z "z"
82                OpName %outparm "outparm"
83                OpDecorate %outparm Location 0
84        %void = OpTypeVoid
85           %3 = OpTypeFunction %void
86         %int = OpTypeInt 32 1
87 %_ptr_Function_int = OpTypePointer Function %int
88       %int_4 = OpConstant %int 4
89       %int_3 = OpConstant %int 3
90       %int_1 = OpConstant %int 1
91 %_ptr_Output_int = OpTypePointer Output %int
92     %outparm = OpVariable %_ptr_Output_int Output
93        %main = OpFunction %void None %3
94           %5 = OpLabel
95           %x = OpVariable %_ptr_Function_int Function
96           %y = OpVariable %_ptr_Function_int Function
97           %z = OpVariable %_ptr_Function_int Function
98                OpStore %x %int_4
99                OpStore %y %int_3
100                OpStore %z %int_1
101          %20 = OpLoad %int %z
102                OpStore %outparm %20
103                OpReturn
104                OpFunctionEnd
105                )";
106   Assemble(spv_asm);
107 
108   const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) {
109     *dest_bb = nullptr;
110     if (instr->opcode() == SpvOpStore) {
111       uint32_t lhs_id = instr->GetSingleWordOperand(0);
112       uint32_t rhs_id = instr->GetSingleWordOperand(1);
113       Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
114       if (rhs_def->opcode() == SpvOpConstant) {
115         uint32_t val = rhs_def->GetSingleWordOperand(2);
116         values_[lhs_id] = val;
117         return SSAPropagator::kInteresting;
118       }
119     }
120     return SSAPropagator::kVarying;
121   };
122 
123   EXPECT_TRUE(Propagate(visit_fn));
124   EXPECT_THAT(GetValues(), UnorderedElementsAre(4, 3, 1));
125 }
126 
TEST_F(PropagatorTest,PropagateThroughPhis)127 TEST_F(PropagatorTest, PropagateThroughPhis) {
128   const std::string spv_asm = R"(
129                OpCapability Shader
130           %1 = OpExtInstImport "GLSL.std.450"
131                OpMemoryModel Logical GLSL450
132                OpEntryPoint Fragment %main "main" %x %outparm
133                OpExecutionMode %main OriginUpperLeft
134                OpSource GLSL 450
135                OpName %main "main"
136                OpName %x "x"
137                OpName %outparm "outparm"
138                OpDecorate %x Flat
139                OpDecorate %x Location 0
140                OpDecorate %outparm Location 0
141        %void = OpTypeVoid
142           %3 = OpTypeFunction %void
143         %int = OpTypeInt 32 1
144        %bool = OpTypeBool
145 %_ptr_Function_int = OpTypePointer Function %int
146       %int_4 = OpConstant %int 4
147       %int_3 = OpConstant %int 3
148       %int_1 = OpConstant %int 1
149 %_ptr_Input_int = OpTypePointer Input %int
150           %x = OpVariable %_ptr_Input_int Input
151 %_ptr_Output_int = OpTypePointer Output %int
152     %outparm = OpVariable %_ptr_Output_int Output
153        %main = OpFunction %void None %3
154           %4 = OpLabel
155           %5 = OpLoad %int %x
156           %6 = OpSGreaterThan %bool %5 %int_3
157                OpSelectionMerge %25 None
158                OpBranchConditional %6 %22 %23
159          %22 = OpLabel
160           %7 = OpLoad %int %int_4
161                OpBranch %25
162          %23 = OpLabel
163           %8 = OpLoad %int %int_4
164                OpBranch %25
165          %25 = OpLabel
166          %35 = OpPhi %int %7 %22 %8 %23
167                OpStore %outparm %35
168                OpReturn
169                OpFunctionEnd
170                )";
171 
172   Assemble(spv_asm);
173 
174   Instruction* phi_instr = nullptr;
175   const auto visit_fn = [this, &phi_instr](Instruction* instr,
176                                            BasicBlock** dest_bb) {
177     *dest_bb = nullptr;
178     if (instr->opcode() == SpvOpLoad) {
179       uint32_t rhs_id = instr->GetSingleWordOperand(2);
180       Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
181       if (rhs_def->opcode() == SpvOpConstant) {
182         uint32_t val = rhs_def->GetSingleWordOperand(2);
183         values_[instr->result_id()] = val;
184         return SSAPropagator::kInteresting;
185       }
186     } else if (instr->opcode() == SpvOpPhi) {
187       phi_instr = instr;
188       SSAPropagator::PropStatus retval;
189       for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
190         uint32_t phi_arg_id = instr->GetSingleWordOperand(i);
191         auto it = values_.find(phi_arg_id);
192         if (it != values_.end()) {
193           EXPECT_EQ(it->second, 4u);
194           retval = SSAPropagator::kInteresting;
195           values_[instr->result_id()] = it->second;
196         } else {
197           retval = SSAPropagator::kNotInteresting;
198           break;
199         }
200       }
201       return retval;
202     }
203 
204     return SSAPropagator::kVarying;
205   };
206 
207   EXPECT_TRUE(Propagate(visit_fn));
208 
209   // The propagator should've concluded that the Phi instruction has a constant
210   // value of 4.
211   EXPECT_NE(phi_instr, nullptr);
212   EXPECT_EQ(values_[phi_instr->result_id()], 4u);
213 
214   EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u));
215 }
216 
217 }  // namespace
218 }  // namespace opt
219 }  // namespace spvtools
220