1 // Copyright 2019 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "SpirvShader.hpp"
16
17 #include <spirv/unified1/spirv.hpp>
18
19 namespace sw {
20
21 struct SpirvShader::Impl::Group
22 {
23 // Template function to perform a binary operation.
24 // |TYPE| should be the type of the binary operation (as a SIMD::<ScalarType>).
25 // |I| should be a type suitable to initialize the identity value.
26 // |APPLY| should be a callable object that takes two RValue<TYPE> parameters
27 // and returns a new RValue<TYPE> corresponding to the operation's result.
28 template<typename TYPE, typename I, typename APPLY>
BinaryOperationsw::SpirvShader::Impl::Group29 static void BinaryOperation(
30 const SpirvShader *shader,
31 const SpirvShader::InsnIterator &insn,
32 const SpirvShader::EmitState *state,
33 Intermediate &dst,
34 const I identityValue,
35 APPLY &&apply)
36 {
37 SpirvShader::GenericValue value(shader, state, insn.word(5));
38 auto &type = shader->getType(SpirvShader::Type::ID(insn.word(1)));
39 for(auto i = 0u; i < type.sizeInComponents; i++)
40 {
41 auto mask = As<SIMD::UInt>(state->activeLaneMask());
42 auto identity = TYPE(identityValue);
43 SIMD::UInt v_uint = (value.UInt(i) & mask) | (As<SIMD::UInt>(identity) & ~mask);
44 TYPE v = As<TYPE>(v_uint);
45 switch(spv::GroupOperation(insn.word(4)))
46 {
47 case spv::GroupOperationReduce:
48 {
49 // NOTE: floating-point add and multiply are not really commutative so
50 // ensure that all values in the final lanes are identical
51 TYPE v2 = apply(v.xxzz, v.yyww); // [xy] [xy] [zw] [zw]
52 TYPE v3 = apply(v2.xxxx, v2.zzzz); // [xyzw] [xyzw] [xyzw] [xyzw]
53 dst.move(i, v3);
54 break;
55 }
56 case spv::GroupOperationInclusiveScan:
57 {
58 TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
59 TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
60 dst.move(i, v3);
61 break;
62 }
63 case spv::GroupOperationExclusiveScan:
64 {
65 TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
66 TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
67 auto v4 = Shuffle(v3, identity, 0x4012 /* [id, v3.x, v3.y, v3.z] */); // [i] [x] [xy] [xyz]
68 dst.move(i, v4);
69 break;
70 }
71 default:
72 UNSUPPORTED("EmitGroupNonUniform op: %s Group operation: %d",
73 SpirvShader::OpcodeName(type.opcode()).c_str(), insn.word(4));
74 }
75 }
76 }
77 };
78
EmitGroupNonUniform(InsnIterator insn,EmitState * state) const79 SpirvShader::EmitResult SpirvShader::EmitGroupNonUniform(InsnIterator insn, EmitState *state) const
80 {
81 static_assert(SIMD::Width == 4, "EmitGroupNonUniform makes many assumptions that the SIMD vector width is 4");
82
83 auto &type = getType(Type::ID(insn.word(1)));
84 Object::ID resultId = insn.word(2);
85 auto scope = spv::Scope(GetConstScalarInt(insn.word(3)));
86 ASSERT_MSG(scope == spv::ScopeSubgroup, "Scope for Non Uniform Group Operations must be Subgroup for Vulkan 1.1");
87
88 auto &dst = state->createIntermediate(resultId, type.sizeInComponents);
89
90 switch(insn.opcode())
91 {
92 case spv::OpGroupNonUniformElect:
93 {
94 // Result is true only in the active invocation with the lowest id
95 // in the group, otherwise result is false.
96 SIMD::Int active = state->activeLaneMask();
97 // TODO: Would be nice if we could write this as:
98 // elect = active & ~(active.Oxyz | active.OOxy | active.OOOx)
99 auto v0111 = SIMD::Int(0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF);
100 auto elect = active & ~(v0111 & (active.xxyz | active.xxxy | active.xxxx));
101 dst.move(0, elect);
102 break;
103 }
104
105 case spv::OpGroupNonUniformAll:
106 {
107 GenericValue predicate(this, state, insn.word(4));
108 dst.move(0, AndAll(predicate.UInt(0) | ~As<SIMD::UInt>(state->activeLaneMask())));
109 break;
110 }
111
112 case spv::OpGroupNonUniformAny:
113 {
114 GenericValue predicate(this, state, insn.word(4));
115 dst.move(0, OrAll(predicate.UInt(0) & As<SIMD::UInt>(state->activeLaneMask())));
116 break;
117 }
118
119 case spv::OpGroupNonUniformAllEqual:
120 {
121 GenericValue value(this, state, insn.word(4));
122 auto res = SIMD::UInt(0xffffffff);
123 SIMD::UInt active = As<SIMD::UInt>(state->activeLaneMask());
124 SIMD::UInt inactive = ~active;
125 for(auto i = 0u; i < type.sizeInComponents; i++)
126 {
127 SIMD::UInt v = value.UInt(i) & active;
128 SIMD::UInt filled = v;
129 for(int j = 0; j < SIMD::Width - 1; j++)
130 {
131 filled |= filled.yzwx & inactive; // Populate inactive 'holes' with a live value
132 }
133 res &= AndAll(CmpEQ(filled.xyzw, filled.yzwx));
134 }
135 dst.move(0, res);
136 break;
137 }
138
139 case spv::OpGroupNonUniformBroadcast:
140 {
141 auto valueId = Object::ID(insn.word(4));
142 auto id = SIMD::Int(GetConstScalarInt(insn.word(5)));
143 GenericValue value(this, state, valueId);
144 auto mask = CmpEQ(id, SIMD::Int(0, 1, 2, 3));
145 for(auto i = 0u; i < type.sizeInComponents; i++)
146 {
147 dst.move(i, OrAll(value.Int(i) & mask));
148 }
149 break;
150 }
151
152 case spv::OpGroupNonUniformBroadcastFirst:
153 {
154 auto valueId = Object::ID(insn.word(4));
155 GenericValue value(this, state, valueId);
156 // Result is true only in the active invocation with the lowest id
157 // in the group, otherwise result is false.
158 SIMD::Int active = state->activeLaneMask();
159 // TODO: Would be nice if we could write this as:
160 // elect = active & ~(active.Oxyz | active.OOxy | active.OOOx)
161 auto v0111 = SIMD::Int(0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF);
162 auto elect = active & ~(v0111 & (active.xxyz | active.xxxy | active.xxxx));
163 for(auto i = 0u; i < type.sizeInComponents; i++)
164 {
165 dst.move(i, OrAll(value.Int(i) & elect));
166 }
167 break;
168 }
169
170 case spv::OpGroupNonUniformBallot:
171 {
172 ASSERT(type.sizeInComponents == 4);
173 GenericValue predicate(this, state, insn.word(4));
174 dst.move(0, SIMD::Int(SignMask(state->activeLaneMask() & predicate.Int(0))));
175 dst.move(1, SIMD::Int(0));
176 dst.move(2, SIMD::Int(0));
177 dst.move(3, SIMD::Int(0));
178 break;
179 }
180
181 case spv::OpGroupNonUniformInverseBallot:
182 {
183 auto valueId = Object::ID(insn.word(4));
184 ASSERT(type.sizeInComponents == 1);
185 ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
186 GenericValue value(this, state, valueId);
187 auto bit = (value.Int(0) >> SIMD::Int(0, 1, 2, 3)) & SIMD::Int(1);
188 dst.move(0, -bit);
189 break;
190 }
191
192 case spv::OpGroupNonUniformBallotBitExtract:
193 {
194 auto valueId = Object::ID(insn.word(4));
195 auto indexId = Object::ID(insn.word(5));
196 ASSERT(type.sizeInComponents == 1);
197 ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
198 ASSERT(getType(getObject(indexId).type).sizeInComponents == 1);
199 GenericValue value(this, state, valueId);
200 GenericValue index(this, state, indexId);
201 auto vecIdx = index.Int(0) / SIMD::Int(32);
202 auto bitIdx = index.Int(0) & SIMD::Int(31);
203 auto bits = (value.Int(0) & CmpEQ(vecIdx, SIMD::Int(0))) |
204 (value.Int(1) & CmpEQ(vecIdx, SIMD::Int(1))) |
205 (value.Int(2) & CmpEQ(vecIdx, SIMD::Int(2))) |
206 (value.Int(3) & CmpEQ(vecIdx, SIMD::Int(3)));
207 dst.move(0, -((bits >> bitIdx) & SIMD::Int(1)));
208 break;
209 }
210
211 case spv::OpGroupNonUniformBallotBitCount:
212 {
213 auto operation = spv::GroupOperation(insn.word(4));
214 auto valueId = Object::ID(insn.word(5));
215 ASSERT(type.sizeInComponents == 1);
216 ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
217 GenericValue value(this, state, valueId);
218 switch(operation)
219 {
220 case spv::GroupOperationReduce:
221 dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(15)));
222 break;
223 case spv::GroupOperationInclusiveScan:
224 dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(1, 3, 7, 15)));
225 break;
226 case spv::GroupOperationExclusiveScan:
227 dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(0, 1, 3, 7)));
228 break;
229 default:
230 UNSUPPORTED("GroupOperation %d", int(operation));
231 }
232 break;
233 }
234
235 case spv::OpGroupNonUniformBallotFindLSB:
236 {
237 auto valueId = Object::ID(insn.word(4));
238 ASSERT(type.sizeInComponents == 1);
239 ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
240 GenericValue value(this, state, valueId);
241 dst.move(0, Cttz(value.UInt(0) & SIMD::UInt(15), true));
242 break;
243 }
244
245 case spv::OpGroupNonUniformBallotFindMSB:
246 {
247 auto valueId = Object::ID(insn.word(4));
248 ASSERT(type.sizeInComponents == 1);
249 ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
250 GenericValue value(this, state, valueId);
251 dst.move(0, SIMD::UInt(31) - Ctlz(value.UInt(0) & SIMD::UInt(15), false));
252 break;
253 }
254
255 case spv::OpGroupNonUniformShuffle:
256 {
257 GenericValue value(this, state, insn.word(4));
258 GenericValue id(this, state, insn.word(5));
259 auto x = CmpEQ(SIMD::Int(0), id.Int(0));
260 auto y = CmpEQ(SIMD::Int(1), id.Int(0));
261 auto z = CmpEQ(SIMD::Int(2), id.Int(0));
262 auto w = CmpEQ(SIMD::Int(3), id.Int(0));
263 for(auto i = 0u; i < type.sizeInComponents; i++)
264 {
265 SIMD::Int v = value.Int(i);
266 dst.move(i, (x & v.xxxx) | (y & v.yyyy) | (z & v.zzzz) | (w & v.wwww));
267 }
268 break;
269 }
270
271 case spv::OpGroupNonUniformShuffleXor:
272 {
273 GenericValue value(this, state, insn.word(4));
274 GenericValue mask(this, state, insn.word(5));
275 auto x = CmpEQ(SIMD::Int(0), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
276 auto y = CmpEQ(SIMD::Int(1), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
277 auto z = CmpEQ(SIMD::Int(2), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
278 auto w = CmpEQ(SIMD::Int(3), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
279 for(auto i = 0u; i < type.sizeInComponents; i++)
280 {
281 SIMD::Int v = value.Int(i);
282 dst.move(i, (x & v.xxxx) | (y & v.yyyy) | (z & v.zzzz) | (w & v.wwww));
283 }
284 break;
285 }
286
287 case spv::OpGroupNonUniformShuffleUp:
288 {
289 GenericValue value(this, state, insn.word(4));
290 GenericValue delta(this, state, insn.word(5));
291 auto d0 = CmpEQ(SIMD::Int(0), delta.Int(0));
292 auto d1 = CmpEQ(SIMD::Int(1), delta.Int(0));
293 auto d2 = CmpEQ(SIMD::Int(2), delta.Int(0));
294 auto d3 = CmpEQ(SIMD::Int(3), delta.Int(0));
295 for(auto i = 0u; i < type.sizeInComponents; i++)
296 {
297 SIMD::Int v = value.Int(i);
298 dst.move(i, (d0 & v.xyzw) | (d1 & v.xxyz) | (d2 & v.xxxy) | (d3 & v.xxxx));
299 }
300 break;
301 }
302
303 case spv::OpGroupNonUniformShuffleDown:
304 {
305 GenericValue value(this, state, insn.word(4));
306 GenericValue delta(this, state, insn.word(5));
307 auto d0 = CmpEQ(SIMD::Int(0), delta.Int(0));
308 auto d1 = CmpEQ(SIMD::Int(1), delta.Int(0));
309 auto d2 = CmpEQ(SIMD::Int(2), delta.Int(0));
310 auto d3 = CmpEQ(SIMD::Int(3), delta.Int(0));
311 for(auto i = 0u; i < type.sizeInComponents; i++)
312 {
313 SIMD::Int v = value.Int(i);
314 dst.move(i, (d0 & v.xyzw) | (d1 & v.yzww) | (d2 & v.zwww) | (d3 & v.wwww));
315 }
316 break;
317 }
318
319 case spv::OpGroupNonUniformIAdd:
320 Impl::Group::BinaryOperation<SIMD::Int>(
321 this, insn, state, dst, 0,
322 [](auto a, auto b) { return a + b; });
323 break;
324
325 case spv::OpGroupNonUniformFAdd:
326 Impl::Group::BinaryOperation<SIMD::Float>(
327 this, insn, state, dst, 0.0f,
328 [](auto a, auto b) { return a + b; });
329 break;
330
331 case spv::OpGroupNonUniformIMul:
332 Impl::Group::BinaryOperation<SIMD::Int>(
333 this, insn, state, dst, 1,
334 [](auto a, auto b) { return a * b; });
335 break;
336
337 case spv::OpGroupNonUniformFMul:
338 Impl::Group::BinaryOperation<SIMD::Float>(
339 this, insn, state, dst, 1.0f,
340 [](auto a, auto b) { return a * b; });
341 break;
342
343 case spv::OpGroupNonUniformBitwiseAnd:
344 Impl::Group::BinaryOperation<SIMD::UInt>(
345 this, insn, state, dst, ~0u,
346 [](auto a, auto b) { return a & b; });
347 break;
348
349 case spv::OpGroupNonUniformBitwiseOr:
350 Impl::Group::BinaryOperation<SIMD::UInt>(
351 this, insn, state, dst, 0,
352 [](auto a, auto b) { return a | b; });
353 break;
354
355 case spv::OpGroupNonUniformBitwiseXor:
356 Impl::Group::BinaryOperation<SIMD::UInt>(
357 this, insn, state, dst, 0,
358 [](auto a, auto b) { return a ^ b; });
359 break;
360
361 case spv::OpGroupNonUniformSMin:
362 Impl::Group::BinaryOperation<SIMD::Int>(
363 this, insn, state, dst, INT32_MAX,
364 [](auto a, auto b) { return Min(a, b); });
365 break;
366
367 case spv::OpGroupNonUniformUMin:
368 Impl::Group::BinaryOperation<SIMD::UInt>(
369 this, insn, state, dst, ~0u,
370 [](auto a, auto b) { return Min(a, b); });
371 break;
372
373 case spv::OpGroupNonUniformFMin:
374 Impl::Group::BinaryOperation<SIMD::Float>(
375 this, insn, state, dst, SIMD::Float::infinity(),
376 [](auto a, auto b) { return NMin(a, b); });
377 break;
378
379 case spv::OpGroupNonUniformSMax:
380 Impl::Group::BinaryOperation<SIMD::Int>(
381 this, insn, state, dst, INT32_MIN,
382 [](auto a, auto b) { return Max(a, b); });
383 break;
384
385 case spv::OpGroupNonUniformUMax:
386 Impl::Group::BinaryOperation<SIMD::UInt>(
387 this, insn, state, dst, 0,
388 [](auto a, auto b) { return Max(a, b); });
389 break;
390
391 case spv::OpGroupNonUniformFMax:
392 Impl::Group::BinaryOperation<SIMD::Float>(
393 this, insn, state, dst, -SIMD::Float::infinity(),
394 [](auto a, auto b) { return NMax(a, b); });
395 break;
396
397 case spv::OpGroupNonUniformLogicalAnd:
398 Impl::Group::BinaryOperation<SIMD::UInt>(
399 this, insn, state, dst, ~0u,
400 [](auto a, auto b) {
401 SIMD::UInt zero = SIMD::UInt(0);
402 return CmpNEQ(a, zero) & CmpNEQ(b, zero);
403 });
404 break;
405
406 case spv::OpGroupNonUniformLogicalOr:
407 Impl::Group::BinaryOperation<SIMD::UInt>(
408 this, insn, state, dst, 0,
409 [](auto a, auto b) {
410 SIMD::UInt zero = SIMD::UInt(0);
411 return CmpNEQ(a, zero) | CmpNEQ(b, zero);
412 });
413 break;
414
415 case spv::OpGroupNonUniformLogicalXor:
416 Impl::Group::BinaryOperation<SIMD::UInt>(
417 this, insn, state, dst, 0,
418 [](auto a, auto b) {
419 SIMD::UInt zero = SIMD::UInt(0);
420 return CmpNEQ(a, zero) ^ CmpNEQ(b, zero);
421 });
422 break;
423
424 default:
425 UNSUPPORTED("EmitGroupNonUniform op: %s", OpcodeName(type.opcode()).c_str());
426 }
427 return EmitResult::Continue;
428 }
429
430 } // namespace sw
431