1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <vector>
16 
17 #include "test/opt/pass_fixture.h"
18 #include "test/opt/pass_utils.h"
19 
20 namespace spvtools {
21 namespace opt {
22 namespace {
23 
// Fixture for the SplitInvalidUnreachablePass tests in this file. Each test
// builds its input (and expected output) text from the fragments and block
// helpers defined below, then checks the pass with SinglePassRunAndCheck.
using SplitInvalidUnreachableTest = PassTest<::testing::Test>;
25 
// Module preamble shared by every test: Shader + Vulkan memory model
// capabilities, a Vertex entry point "shader" for function %1, and the types
// and constants (%uint_1, %uint_2, %bool, %void, %7) referenced by the block
// helpers below. Declared const because every test only reads it; a mutable
// shared global could leak modifications from one test into the next.
const std::string spirv_header = R"(OpCapability Shader
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint Vertex %1 "shader"
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%void = OpTypeVoid
%bool = OpTypeBool
%7 = OpTypeFunction %void
)";
38 
// Opens the entry-point function %1: its first block %8 branches
// unconditionally to %9, so every test must define a block labelled %9.
// Const because it is shared, read-only test data.
const std::string function_head = R"(%1 = OpFunction %void None %7
%8 = OpLabel
OpBranch %9
)";
43 
// Closes the function opened by function_head. Const: shared, read-only.
const std::string function_tail = "OpFunctionEnd\n";
45 
// Returns the text of a loop header block labelled |block_id|.
//
// The block declares |merge_id| and |continue_id| via OpLoopMerge (loop
// control "None") and then branches unconditionally to |body_id|. Every
// emitted instruction ends in '\n'. Parameters are taken by const reference
// because they are only read, never stored or moved.
std::string GetLoopMergeBlock(const std::string& block_id,
                              const std::string& merge_id,
                              const std::string& continue_id,
                              const std::string& body_id) {
  std::string result;
  result += block_id + " = OpLabel\n";
  result += "OpLoopMerge " + merge_id + " " + continue_id + " None\n";
  result += "OpBranch " + body_id + "\n";
  return result;
}
54 
// Returns the text of a selection header block labelled |block_id|.
//
// The block defines |condition_id| as `OpSLessThan %bool %uint_1 %uint_2`
// (constants declared in spirv_header), declares |merge_id| as the selection
// merge (control "None"), and branches conditionally to |true_id| /
// |false_id|. Every emitted instruction ends in '\n'. Parameters are taken by
// const reference because they are only read.
std::string GetSelectionMergeBlock(const std::string& block_id,
                                   const std::string& condition_id,
                                   const std::string& merge_id,
                                   const std::string& true_id,
                                   const std::string& false_id) {
  std::string result;
  result += block_id + " = OpLabel\n";
  result += condition_id + " = OpSLessThan %bool %uint_1 %uint_2\n";
  result += "OpSelectionMerge " + merge_id + " None\n";
  result += "OpBranchConditional " + condition_id + " " + true_id + " " +
            false_id + "\n";
  return result;
}
68 
// Returns the text of a block labelled |block_id| that ends in OpReturn.
// |block_id| is only read, so it is taken by const reference.
std::string GetReturnBlock(const std::string& block_id) {
  std::string result;
  result += block_id + " = OpLabel\n";
  result += "OpReturn\n";
  return result;
}
75 
// Returns the text of a block labelled |block_id| that ends in OpUnreachable.
// |block_id| is only read, so it is taken by const reference.
std::string GetUnreachableBlock(const std::string& block_id) {
  std::string result;
  result += block_id + " = OpLabel\n";
  result += "OpUnreachable\n";
  return result;
}
82 
// Returns the text of a block labelled |block_id| that branches
// unconditionally to |target_id|. Both parameters are only read, so they are
// taken by const reference.
std::string GetBranchBlock(const std::string& block_id,
                           const std::string& target_id) {
  std::string result;
  result += block_id + " = OpLabel\n";
  result += "OpBranch " + target_id + "\n";
  return result;
}
89 
TEST_F(SplitInvalidUnreachableTest, NoInvalidBlocks) {
  // Loop %9 (merge %10, continue %11) encloses a selection in %12 whose merge
  // is its own OpUnreachable block %14, distinct from the continue target.
  // The pass is expected to leave this module unchanged, so the same text is
  // used as both input and expected output.
  const std::string module_text =
      spirv_header + function_head +
      GetLoopMergeBlock("%9", "%10", "%11", "%12") +
      GetSelectionMergeBlock("%12", "%13", "%14", "%15", "%16") +
      GetReturnBlock("%15") + GetReturnBlock("%16") +
      GetUnreachableBlock("%10") + GetBranchBlock("%11", "%9") +
      GetUnreachableBlock("%14") + function_tail;

  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(module_text, module_text,
                                                     /* skip_nop = */ false);
}
104 
TEST_F(SplitInvalidUnreachableTest, SelectionInLoop) {
  // Input: the selection in %12 names the loop's continue target %11 as its
  // merge block.
  const std::string before =
      spirv_header + function_head +
      GetLoopMergeBlock("%9", "%10", "%11", "%12") +
      GetSelectionMergeBlock("%12", "%13", "%11", "%15", "%16") +
      GetReturnBlock("%15") + GetReturnBlock("%16") +
      GetUnreachableBlock("%10") + GetBranchBlock("%11", "%9") +
      function_tail;

  // Expected: the selection merge is redirected to a new OpUnreachable block
  // %16 while %11 remains the continue target. The two return blocks appear
  // as %14/%15 here; presumably ids are renumbered because %14 is unused in
  // the input text -- TODO confirm against the assembler's id assignment.
  const std::string after =
      spirv_header + function_head +
      GetLoopMergeBlock("%9", "%10", "%11", "%12") +
      GetSelectionMergeBlock("%12", "%13", "%16", "%14", "%15") +
      GetReturnBlock("%14") + GetReturnBlock("%15") +
      GetUnreachableBlock("%10") + GetUnreachableBlock("%16") +
      GetBranchBlock("%11", "%9") + function_tail;

  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(before, after,
                                                     /* skip_nop = */ false);
}
128 
TEST_F(SplitInvalidUnreachableTest, LoopInSelection) {
  // Input: the outer selection in %9 uses %11 as its merge block, and %11 is
  // also the continue target of the loop headed by %12.
  const std::string before =
      spirv_header + function_head +
      GetSelectionMergeBlock("%9", "%10", "%11", "%12", "%13") +
      GetLoopMergeBlock("%12", "%14", "%11", "%15") + GetReturnBlock("%13") +
      GetUnreachableBlock("%14") + GetBranchBlock("%11", "%12") +
      GetReturnBlock("%15") + function_tail;

  // Expected: the selection now merges at a new OpUnreachable block %16
  // (placed after %14), and %11 keeps serving only as the continue target.
  const std::string after =
      spirv_header + function_head +
      GetSelectionMergeBlock("%9", "%10", "%16", "%12", "%13") +
      GetLoopMergeBlock("%12", "%14", "%11", "%15") + GetReturnBlock("%13") +
      GetUnreachableBlock("%14") + GetUnreachableBlock("%16") +
      GetBranchBlock("%11", "%12") + GetReturnBlock("%15") + function_tail;

  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(before, after,
                                                     /* skip_nop = */ false);
}
152 
153 }  // namespace
154 }  // namespace opt
155 }  // namespace spvtools
156