1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <sstream>
16 #include <string>
17 #include <tuple>
18 
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
27 using ::testing::Combine;
28 using ::testing::HasSubstr;
29 using ::testing::Values;
30 using ::testing::ValuesIn;
31 
// Returns a complete SPIR-V assembly module wrapping |body| inside the entry
// function %main. The module declares the Shader capability plus all
// GroupNonUniform* capabilities, then appends |capabilities_and_extensions|
// verbatim before OpMemoryModel. |execution_model| names the entry point's
// execution model; only "GLCompute" also receives a LocalSize 1 1 1 execution
// mode. A shared set of types and constants (including one %u32 constant per
// scope and per group operation) is declared ahead of %main.
std::string GenerateShaderCode(
    const std::string& body,
    const std::string& capabilities_and_extensions = "",
    const std::string& execution_model = "GLCompute") {
  // Capabilities emitted at the very top of every generated module.
  const char* const kCapabilities = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformVote
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformClustered
OpCapability GroupNonUniformQuad
)";

  // Types, constants and the opening of %main that tests refer to by id.
  const char* const kDeclarations = R"(
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%u32 = OpTypeInt 32 0
%int = OpTypeInt 32 1
%float = OpTypeFloat 32
%u32vec4 = OpTypeVector %u32 4
%u32vec3 = OpTypeVector %u32 3

%true = OpConstantTrue %bool
%false = OpConstantFalse %bool

%u32_0 = OpConstant %u32 0

%float_0 = OpConstant %float 0

%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
%u32vec3_null = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0

%cross_device = OpConstant %u32 0
%device = OpConstant %u32 1
%workgroup = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%invocation = OpConstant %u32 4

%reduce = OpConstant %u32 0
%inclusive_scan = OpConstant %u32 1
%exclusive_scan = OpConstant %u32 2
%clustered_reduce = OpConstant %u32 3

%main = OpFunction %void None %func
%main_entry = OpLabel
)";

  // Terminates the entry function after the caller-provided body.
  const char* const kEpilogue = R"(
OpReturn
OpFunctionEnd)";

  std::string spirv = kCapabilities;
  spirv += capabilities_and_extensions;
  spirv += "OpMemoryModel Logical GLSL450\n";
  spirv += "OpEntryPoint " + execution_model + " %main \"main\"\n";

  // Compute shaders need a workgroup size; other models get no mode.
  const bool needs_local_size = (execution_model == "GLCompute");
  if (needs_local_size) {
    spirv += "OpExecutionMode %main LocalSize 1 1 1\n";
  }

  spirv += kDeclarations;
  spirv += body;
  spirv += kEpilogue;
  return spirv;
}
99 
100 SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
101                      SpvScopeSubgroup, SpvScopeInvocation};
102 
// Fixture for non-parameterized TEST_F cases. The bool parameter is not read by
// any test visible in this section.
using ValidateGroupNonUniform = spvtest::ValidateBase<bool>;
// Parameterized fixture. The tuple fields, in order, are:
//   0: opcode name (e.g. "OpGroupNonUniformAll"),
//   1: result type id (e.g. "%bool"),
//   2: execution scope, rendered as an id by ConvertScope(),
//   3: remaining operands appended after the scope (may be empty),
//   4: expected diagnostic substring; "" means the expected outcome depends
//      only on the execution scope and target environment.
using GroupNonUniform = spvtest::ValidateBase<
    std::tuple<std::string, std::string, SpvScope, std::string, std::string>>;
106 
ConvertScope(SpvScope scope)107 std::string ConvertScope(SpvScope scope) {
108   switch (scope) {
109     case SpvScopeCrossDevice:
110       return "%cross_device";
111     case SpvScopeDevice:
112       return "%device";
113     case SpvScopeWorkgroup:
114       return "%workgroup";
115     case SpvScopeSubgroup:
116       return "%subgroup";
117     case SpvScopeInvocation:
118       return "%invocation";
119     default:
120       return "";
121   }
122 }
123 
TEST_P(GroupNonUniform,Vulkan1p1)124 TEST_P(GroupNonUniform, Vulkan1p1) {
125   std::string opcode = std::get<0>(GetParam());
126   std::string type = std::get<1>(GetParam());
127   SpvScope execution_scope = std::get<2>(GetParam());
128   std::string args = std::get<3>(GetParam());
129   std::string error = std::get<4>(GetParam());
130 
131   std::ostringstream sstr;
132   sstr << "%result = " << opcode << " ";
133   sstr << type << " ";
134   sstr << ConvertScope(execution_scope) << " ";
135   sstr << args << "\n";
136 
137   CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1);
138   spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);
139   if (error == "") {
140     if (execution_scope == SpvScopeSubgroup) {
141       EXPECT_EQ(SPV_SUCCESS, result);
142     } else {
143       EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
144       EXPECT_THAT(getDiagnosticString(),
145                   AnyVUID("VUID-StandaloneSpirv-None-04642"));
146       EXPECT_THAT(
147           getDiagnosticString(),
148           HasSubstr(
149               "in Vulkan environment Execution scope is limited to Subgroup"));
150     }
151   } else {
152     EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
153     EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
154   }
155 }
156 
TEST_P(GroupNonUniform,Spirv1p3)157 TEST_P(GroupNonUniform, Spirv1p3) {
158   std::string opcode = std::get<0>(GetParam());
159   std::string type = std::get<1>(GetParam());
160   SpvScope execution_scope = std::get<2>(GetParam());
161   std::string args = std::get<3>(GetParam());
162   std::string error = std::get<4>(GetParam());
163 
164   std::ostringstream sstr;
165   sstr << "%result = " << opcode << " ";
166   sstr << type << " ";
167   sstr << ConvertScope(execution_scope) << " ";
168   sstr << args << "\n";
169 
170   CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3);
171   spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);
172   if (error == "") {
173     if (execution_scope == SpvScopeSubgroup ||
174         execution_scope == SpvScopeWorkgroup) {
175       EXPECT_EQ(SPV_SUCCESS, result);
176     } else {
177       EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
178       EXPECT_THAT(
179           getDiagnosticString(),
180           HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
181     }
182   } else {
183     EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
184     EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
185   }
186 }
187 
// Each suite below pairs one or more opcodes with a well-formed result type and
// operands, sweeps every execution scope, and passes an empty expected error.
// The Vulkan1p1 and Spirv1p3 tests therefore decide success or failure from the
// execution scope alone.

// OpGroupNonUniformElect: no operands after the execution scope.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformElect, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformElect"),
                                 Values("%bool"), ValuesIn(scopes), Values(""),
                                 Values("")));

// Vote instructions: a single boolean predicate operand.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformVote, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformAll",
                                        "OpGroupNonUniformAny",
                                        "OpGroupNonUniformAllEqual"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%true"), Values("")));

// Broadcast: a value operand followed by a constant id operand.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcast, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBroadcast"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%true %u32_0"), Values("")));

// BroadcastFirst: a single value operand.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcastFirst, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBroadcastFirst"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%true"), Values("")));

// Ballot: boolean predicate in, four-component %u32 vector out.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallot, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallot"),
                                 Values("%u32vec4"), ValuesIn(scopes),
                                 Values("%true"), Values("")));

// InverseBallot: four-component %u32 vector in, boolean out.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformInverseBallot, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformInverseBallot"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%u32vec4_null"), Values("")));

// BallotBitExtract: ballot vector plus a %u32 index.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitExtract, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitExtract"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%u32vec4_null %u32_0"), Values("")));

// BallotBitCount: group operation (Reduce) followed by the ballot vector.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCount, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitCount"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("Reduce %u32vec4_null"), Values("")));

// BallotFindLSB / BallotFindMSB: ballot vector in, %u32 out.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotFind, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotFindLSB",
                                        "OpGroupNonUniformBallotFindMSB"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("%u32vec4_null"), Values("")));

// Shuffle family: a value operand plus a %u32 id/mask/delta operand.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformShuffle, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformShuffle",
                                        "OpGroupNonUniformShuffleXor",
                                        "OpGroupNonUniformShuffleUp",
                                        "OpGroupNonUniformShuffleDown"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("%u32_0 %u32_0"), Values("")));

// Integer arithmetic and bitwise reductions on a %u32 value.
INSTANTIATE_TEST_SUITE_P(
    GroupNonUniformIntegerArithmetic, GroupNonUniform,
    Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
                   "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
                   "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
                   "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
                   "OpGroupNonUniformBitwiseXor"),
            Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"),
            Values("")));

// Floating-point reductions on a %float value.
INSTANTIATE_TEST_SUITE_P(
    GroupNonUniformFloatArithmetic, GroupNonUniform,
    Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
                   "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
            Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"),
            Values("")));

// Logical reductions on a %bool value.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformLogicalArithmetic, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformLogicalAnd",
                                        "OpGroupNonUniformLogicalOr",
                                        "OpGroupNonUniformLogicalXor"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("Reduce %true"), Values("")));

// Quad operations: a value operand plus a constant index/direction operand.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformQuad, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformQuadBroadcast",
                                        "OpGroupNonUniformQuadSwap"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("%u32_0 %u32_0"), Values("")));
273 
274 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountScope, GroupNonUniform,
275                          Combine(Values("OpGroupNonUniformBallotBitCount"),
276                                  Values("%u32"), ValuesIn(scopes),
277                                  Values("Reduce %u32vec4_null"), Values("")));
278 
// Negative cases for OpGroupNonUniformBallotBitCount, pinned to Subgroup scope
// so that the only reported failure is the operand-type error given in the last
// tuple field.

// Result types that are not an unsigned integer scalar: a float and a signed
// integer.
INSTANTIATE_TEST_SUITE_P(
    GroupNonUniformBallotBitCountBadResultType, GroupNonUniform,
    Combine(
        Values("OpGroupNonUniformBallotBitCount"), Values("%float", "%int"),
        Values(SpvScopeSubgroup), Values("Reduce %u32vec4_null"),
        Values("Expected Result Type to be an unsigned integer type scalar.")));

// Value operands that are not a four-component %u32 vector: a three-component
// vector, a %u32 scalar and a %float scalar.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountBadValue, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitCount"),
                                 Values("%u32"), Values(SpvScopeSubgroup),
                                 Values("Reduce %u32vec3_null", "Reduce %u32_0",
                                        "Reduce %float_0"),
                                 Values("Expected Value to be a vector of four "
                                        "components of integer type scalar")));
293 
// Under Vulkan 1.1, OpGroupNonUniformBallotBitCount may only use the Reduce,
// InclusiveScan or ExclusiveScan group operations. A ClusteredReduce operand
// must be rejected with VUID-StandaloneSpirv-OpGroupNonUniformBallotBitCount-04685.
TEST_F(ValidateGroupNonUniform, VulkanGroupNonUniformBallotBitCountOperation) {
  const std::string spirv = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformClustered
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%func = OpTypeFunction %void
%u32 = OpTypeInt 32 0
%u32vec4 = OpTypeVector %u32 4
%u32_0 = OpConstant %u32 0
%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
%subgroup = OpConstant %u32 3
%main = OpFunction %void None %func
%main_entry = OpLabel
%result = OpGroupNonUniformBallotBitCount %u32 %subgroup ClusteredReduce %u32vec4_null
OpReturn
OpFunctionEnd
)";

  const spv_target_env env = SPV_ENV_VULKAN_1_1;
  CompileSuccessfully(spirv, env);
  const spv_result_t result = ValidateInstructions(env);
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, result);

  const std::string diagnostic = getDiagnosticString();
  EXPECT_THAT(
      diagnostic,
      AnyVUID("VUID-StandaloneSpirv-OpGroupNonUniformBallotBitCount-04685"));
  EXPECT_THAT(
      diagnostic,
      HasSubstr(
          "In Vulkan: The OpGroupNonUniformBallotBitCount group operation must "
          "be only: Reduce, InclusiveScan, or ExclusiveScan."));
}
328 
329 }  // namespace
330 }  // namespace val
331 }  // namespace spvtools
332