1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <sstream>
16 #include <string>
17 #include <tuple>
18 
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
27 using ::testing::Combine;
28 using ::testing::HasSubstr;
29 using ::testing::Values;
30 using ::testing::ValuesIn;
31 
// Assembles a complete SPIR-V module around |body| for the group
// non-uniform validation tests.
//
// |body| is spliced into %main directly after its entry label. It may refer
// to any id declared in the module prologue below:
// - the types %bool, %u32, %int, %float, %u32vec4 and %u32vec3;
// - the constants %true, %false, %u32_0 and %float_0;
// - the zero vectors %u32vec4_null and %u32vec3_null;
// - the scope constants %cross_device .. %invocation;
// - the constants %reduce .. %clustered_reduce, which are named after group
//   operations.
// |capabilities_and_extensions| is emitted verbatim after the fixed
// capability list. |execution_model| names the entry point's execution
// model; a LocalSize execution mode is added only for "GLCompute".
std::string GenerateShaderCode(
    const std::string& body,
    const std::string& capabilities_and_extensions = "",
    const std::string& execution_model = "GLCompute") {
  // Capabilities enabling every group non-uniform opcode under test.
  const std::string kCapabilities = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformVote
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformClustered
OpCapability GroupNonUniformQuad
)";

  // Shared types and constants, followed by the opening of %main.
  const std::string kDeclarations = R"(
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%u32 = OpTypeInt 32 0
%int = OpTypeInt 32 1
%float = OpTypeFloat 32
%u32vec4 = OpTypeVector %u32 4
%u32vec3 = OpTypeVector %u32 3

%true = OpConstantTrue %bool
%false = OpConstantFalse %bool

%u32_0 = OpConstant %u32 0

%float_0 = OpConstant %float 0

%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
%u32vec3_null = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0

%cross_device = OpConstant %u32 0
%device = OpConstant %u32 1
%workgroup = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%invocation = OpConstant %u32 4

%reduce = OpConstant %u32 0
%inclusive_scan = OpConstant %u32 1
%exclusive_scan = OpConstant %u32 2
%clustered_reduce = OpConstant %u32 3

%main = OpFunction %void None %func
%main_entry = OpLabel
)";

  // Closes %main after the caller-supplied body.
  const std::string kEpilogue = R"(
OpReturn
OpFunctionEnd)";

  std::string module = kCapabilities + capabilities_and_extensions;
  module += "OpMemoryModel Logical GLSL450\n";
  module += "OpEntryPoint " + execution_model + " %main \"main\"\n";
  // Only compute entry points get a workgroup size.
  if (execution_model == "GLCompute") {
    module += "OpExecutionMode %main LocalSize 1 1 1\n";
  }
  module += kDeclarations;
  module += body;
  module += kEpilogue;
  return module;
}
99 
100 SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
101                      SpvScopeSubgroup, SpvScopeInvocation};
102 
103 using GroupNonUniform = spvtest::ValidateBase<
104     std::tuple<std::string, std::string, SpvScope, std::string, std::string>>;
105 
ConvertScope(SpvScope scope)106 std::string ConvertScope(SpvScope scope) {
107   switch (scope) {
108     case SpvScopeCrossDevice:
109       return "%cross_device";
110     case SpvScopeDevice:
111       return "%device";
112     case SpvScopeWorkgroup:
113       return "%workgroup";
114     case SpvScopeSubgroup:
115       return "%subgroup";
116     case SpvScopeInvocation:
117       return "%invocation";
118     default:
119       return "";
120   }
121 }
122 
TEST_P(GroupNonUniform,Vulkan1p1)123 TEST_P(GroupNonUniform, Vulkan1p1) {
124   std::string opcode = std::get<0>(GetParam());
125   std::string type = std::get<1>(GetParam());
126   SpvScope execution_scope = std::get<2>(GetParam());
127   std::string args = std::get<3>(GetParam());
128   std::string error = std::get<4>(GetParam());
129 
130   std::ostringstream sstr;
131   sstr << "%result = " << opcode << " ";
132   sstr << type << " ";
133   sstr << ConvertScope(execution_scope) << " ";
134   sstr << args << "\n";
135 
136   CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1);
137   spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);
138   if (error == "") {
139     if (execution_scope == SpvScopeSubgroup) {
140       EXPECT_EQ(SPV_SUCCESS, result);
141     } else {
142       EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
143       EXPECT_THAT(
144           getDiagnosticString(),
145           HasSubstr(
146               "in Vulkan environment Execution scope is limited to Subgroup"));
147     }
148   } else {
149     EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
150     EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
151   }
152 }
153 
TEST_P(GroupNonUniform,Spirv1p3)154 TEST_P(GroupNonUniform, Spirv1p3) {
155   std::string opcode = std::get<0>(GetParam());
156   std::string type = std::get<1>(GetParam());
157   SpvScope execution_scope = std::get<2>(GetParam());
158   std::string args = std::get<3>(GetParam());
159   std::string error = std::get<4>(GetParam());
160 
161   std::ostringstream sstr;
162   sstr << "%result = " << opcode << " ";
163   sstr << type << " ";
164   sstr << ConvertScope(execution_scope) << " ";
165   sstr << args << "\n";
166 
167   CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3);
168   spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);
169   if (error == "") {
170     if (execution_scope == SpvScopeSubgroup ||
171         execution_scope == SpvScopeWorkgroup) {
172       EXPECT_EQ(SPV_SUCCESS, result);
173     } else {
174       EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
175       EXPECT_THAT(
176           getDiagnosticString(),
177           HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
178     }
179   } else {
180     EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
181     EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
182   }
183 }
184 
185 INSTANTIATE_TEST_SUITE_P(GroupNonUniformElect, GroupNonUniform,
186                          Combine(Values("OpGroupNonUniformElect"),
187                                  Values("%bool"), ValuesIn(scopes), Values(""),
188                                  Values("")));
189 
190 INSTANTIATE_TEST_SUITE_P(GroupNonUniformVote, GroupNonUniform,
191                          Combine(Values("OpGroupNonUniformAll",
192                                         "OpGroupNonUniformAny",
193                                         "OpGroupNonUniformAllEqual"),
194                                  Values("%bool"), ValuesIn(scopes),
195                                  Values("%true"), Values("")));
196 
197 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcast, GroupNonUniform,
198                          Combine(Values("OpGroupNonUniformBroadcast"),
199                                  Values("%bool"), ValuesIn(scopes),
200                                  Values("%true %u32_0"), Values("")));
201 
202 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcastFirst, GroupNonUniform,
203                          Combine(Values("OpGroupNonUniformBroadcastFirst"),
204                                  Values("%bool"), ValuesIn(scopes),
205                                  Values("%true"), Values("")));
206 
207 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallot, GroupNonUniform,
208                          Combine(Values("OpGroupNonUniformBallot"),
209                                  Values("%u32vec4"), ValuesIn(scopes),
210                                  Values("%true"), Values("")));
211 
212 INSTANTIATE_TEST_SUITE_P(GroupNonUniformInverseBallot, GroupNonUniform,
213                          Combine(Values("OpGroupNonUniformInverseBallot"),
214                                  Values("%bool"), ValuesIn(scopes),
215                                  Values("%u32vec4_null"), Values("")));
216 
217 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitExtract, GroupNonUniform,
218                          Combine(Values("OpGroupNonUniformBallotBitExtract"),
219                                  Values("%bool"), ValuesIn(scopes),
220                                  Values("%u32vec4_null %u32_0"), Values("")));
221 
222 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCount, GroupNonUniform,
223                          Combine(Values("OpGroupNonUniformBallotBitCount"),
224                                  Values("%u32"), ValuesIn(scopes),
225                                  Values("Reduce %u32vec4_null"), Values("")));
226 
227 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotFind, GroupNonUniform,
228                          Combine(Values("OpGroupNonUniformBallotFindLSB",
229                                         "OpGroupNonUniformBallotFindMSB"),
230                                  Values("%u32"), ValuesIn(scopes),
231                                  Values("%u32vec4_null"), Values("")));
232 
233 INSTANTIATE_TEST_SUITE_P(GroupNonUniformShuffle, GroupNonUniform,
234                          Combine(Values("OpGroupNonUniformShuffle",
235                                         "OpGroupNonUniformShuffleXor",
236                                         "OpGroupNonUniformShuffleUp",
237                                         "OpGroupNonUniformShuffleDown"),
238                                  Values("%u32"), ValuesIn(scopes),
239                                  Values("%u32_0 %u32_0"), Values("")));
240 
241 INSTANTIATE_TEST_SUITE_P(
242     GroupNonUniformIntegerArithmetic, GroupNonUniform,
243     Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
244                    "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
245                    "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
246                    "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
247                    "OpGroupNonUniformBitwiseXor"),
248             Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"),
249             Values("")));
250 
251 INSTANTIATE_TEST_SUITE_P(
252     GroupNonUniformFloatArithmetic, GroupNonUniform,
253     Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
254                    "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
255             Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"),
256             Values("")));
257 
258 INSTANTIATE_TEST_SUITE_P(GroupNonUniformLogicalArithmetic, GroupNonUniform,
259                          Combine(Values("OpGroupNonUniformLogicalAnd",
260                                         "OpGroupNonUniformLogicalOr",
261                                         "OpGroupNonUniformLogicalXor"),
262                                  Values("%bool"), ValuesIn(scopes),
263                                  Values("Reduce %true"), Values("")));
264 
265 INSTANTIATE_TEST_SUITE_P(GroupNonUniformQuad, GroupNonUniform,
266                          Combine(Values("OpGroupNonUniformQuadBroadcast",
267                                         "OpGroupNonUniformQuadSwap"),
268                                  Values("%u32"), ValuesIn(scopes),
269                                  Values("%u32_0 %u32_0"), Values("")));
270 
271 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountScope, GroupNonUniform,
272                          Combine(Values("OpGroupNonUniformBallotBitCount"),
273                                  Values("%u32"), ValuesIn(scopes),
274                                  Values("Reduce %u32vec4_null"), Values("")));
275 
276 INSTANTIATE_TEST_SUITE_P(
277     GroupNonUniformBallotBitCountBadResultType, GroupNonUniform,
278     Combine(
279         Values("OpGroupNonUniformBallotBitCount"), Values("%float", "%int"),
280         Values(SpvScopeSubgroup), Values("Reduce %u32vec4_null"),
281         Values("Expected Result Type to be an unsigned integer type scalar.")));
282 
283 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountBadValue, GroupNonUniform,
284                          Combine(Values("OpGroupNonUniformBallotBitCount"),
285                                  Values("%u32"), Values(SpvScopeSubgroup),
286                                  Values("Reduce %u32vec3_null", "Reduce %u32_0",
287                                         "Reduce %float_0"),
288                                  Values("Expected Value to be a vector of four "
289                                         "components of integer type scalar")));
290 
291 }  // namespace
292 }  // namespace val
293 }  // namespace spvtools
294