1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <vector>
16
17 #include "test/opt/pass_fixture.h"
18 #include "test/opt/pass_utils.h"
19
20 namespace spvtools {
21 namespace opt {
22 namespace {
23
// Test fixture for SplitInvalidUnreachablePass: the generic PassTest harness
// over a plain ::testing::Test, used by every TEST_F in this file.
using SplitInvalidUnreachableTest = PassTest<::testing::Test>;
25
// Module preamble shared by every test: capabilities, the Vulkan memory model,
// a vertex entry point named "shader" for function %1, and the types and
// constants (%uint, %uint_1, %uint_2, %void, %bool, %7) that the block
// helpers below reference. Declared const so no test can mutate the text
// that later tests build on.
const std::string spirv_header = R"(OpCapability Shader
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint Vertex %1 "shader"
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%void = OpTypeVoid
%bool = OpTypeBool
%7 = OpTypeFunction %void
)";
38
// Opens function %1 (the entry point declared in spirv_header). Its first
// block %8 branches unconditionally to %9, so every test must define a block
// labelled %9. Declared const so tests cannot mutate shared state.
const std::string function_head = R"(%1 = OpFunction %void None %7
%8 = OpLabel
OpBranch %9
)";
43
// Closes the function opened by function_head. Declared const so tests cannot
// mutate shared state.
const std::string function_tail = "OpFunctionEnd\n";
45
// Returns the assembly for a loop header block labelled |block_id|: an
// OpLoopMerge naming |merge_id| (merge block) and |continue_id| (continue
// target), followed by an unconditional branch to the body |body_id|.
std::string GetLoopMergeBlock(std::string block_id, std::string merge_id,
                              std::string continue_id, std::string body_id) {
  return block_id + " = OpLabel\n" +
         "OpLoopMerge " + merge_id + " " + continue_id + " None\n" +
         "OpBranch " + body_id + "\n";
}
54
// Returns the assembly for a selection header block labelled |block_id|.
// The block computes |condition_id| from a fixed OpSLessThan on the header's
// constants, declares |merge_id| as the selection merge, and branches to
// |true_id| or |false_id| on that condition.
std::string GetSelectionMergeBlock(std::string block_id,
                                   std::string condition_id,
                                   std::string merge_id, std::string true_id,
                                   std::string false_id) {
  const std::string label = block_id + " = OpLabel\n";
  const std::string compare =
      condition_id + " = OpSLessThan %bool %uint_1 %uint_2\n";
  const std::string merge = "OpSelectionMerge " + merge_id + " None\n";
  const std::string branch = "OpBranchConditional " + condition_id + " " +
                             true_id + " " + false_id + "\n";
  return label + compare + merge + branch;
}
68
// Returns the assembly for a block labelled |block_id| that only returns.
std::string GetReturnBlock(std::string block_id) {
  return block_id + " = OpLabel\n" + "OpReturn\n";
}
75
// Returns the assembly for a block labelled |block_id| whose sole
// instruction is OpUnreachable.
std::string GetUnreachableBlock(std::string block_id) {
  return block_id + " = OpLabel\n" + "OpUnreachable\n";
}
82
// Returns the assembly for a block labelled |block_id| that branches
// unconditionally to |target_id|.
std::string GetBranchBlock(std::string block_id, std::string target_id) {
  std::string text = block_id;
  text.append(" = OpLabel\n");
  text.append("OpBranch ").append(target_id).append("\n");
  return text;
}
89
// The selection inside the loop body merges at %14, which is neither the
// loop's merge block (%10) nor its continue target (%11). The pass is
// expected to return the module unchanged.
TEST_F(SplitInvalidUnreachableTest, NoInvalidBlocks) {
  const std::string module =
      spirv_header + function_head +
      GetLoopMergeBlock("%9", "%10", "%11", "%12") +
      GetSelectionMergeBlock("%12", "%13", "%14", "%15", "%16") +
      GetReturnBlock("%15") + GetReturnBlock("%16") +
      GetUnreachableBlock("%10") + GetBranchBlock("%11", "%9") +
      GetUnreachableBlock("%14") + function_tail;

  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(module, module,
                                                     /* skip_nop = */ false);
}
104
// The selection in the loop body (%12) names the loop's continue target (%11)
// as its merge block. The pass is expected to give the selection a new
// OpUnreachable merge block (%16), emitted just before %11, and to leave every
// other block as it was.
//
// The branch targets in the input must carry the same ids (%14, %15) as in
// the expected output, because the pass adds a block rather than renumbering
// existing ones. The input must also leave %16 unused so it is free for the
// newly created merge block.
TEST_F(SplitInvalidUnreachableTest, SelectionInLoop) {
  std::string input = spirv_header + function_head;
  input += GetLoopMergeBlock("%9", "%10", "%11", "%12");
  input += GetSelectionMergeBlock("%12", "%13", "%11", "%14", "%15");
  input += GetReturnBlock("%14");
  input += GetReturnBlock("%15");
  input += GetUnreachableBlock("%10");
  input += GetBranchBlock("%11", "%9");
  input += function_tail;

  std::string expected = spirv_header + function_head;
  expected += GetLoopMergeBlock("%9", "%10", "%11", "%12");
  expected += GetSelectionMergeBlock("%12", "%13", "%16", "%14", "%15");
  expected += GetReturnBlock("%14");
  expected += GetReturnBlock("%15");
  expected += GetUnreachableBlock("%10");
  expected += GetUnreachableBlock("%16");
  expected += GetBranchBlock("%11", "%9");
  expected += function_tail;

  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, expected,
                                                     /* skip_nop = */ false);
}
128
// Selection header %9 merges at %11, which is also the continue target of the
// loop headed by %12 (the selection's true branch). The pass is expected to
// give the selection a fresh OpUnreachable merge block, %16, emitted just
// before %11, and to leave the remaining blocks untouched.
TEST_F(SplitInvalidUnreachableTest, LoopInSelection) {
  const std::string input =
      spirv_header + function_head +
      GetSelectionMergeBlock("%9", "%10", "%11", "%12", "%13") +
      GetLoopMergeBlock("%12", "%14", "%11", "%15") + GetReturnBlock("%13") +
      GetUnreachableBlock("%14") + GetBranchBlock("%11", "%12") +
      GetReturnBlock("%15") + function_tail;

  const std::string expected =
      spirv_header + function_head +
      GetSelectionMergeBlock("%9", "%10", "%16", "%12", "%13") +
      GetLoopMergeBlock("%12", "%14", "%11", "%15") + GetReturnBlock("%13") +
      GetUnreachableBlock("%14") + GetUnreachableBlock("%16") +
      GetBranchBlock("%11", "%12") + GetReturnBlock("%15") + function_tail;

  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, expected,
                                                     /* skip_nop = */ false);
}
152
153 } // namespace
154 } // namespace opt
155 } // namespace spvtools
156