1 // Copyright 2019 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "SpirvShader.hpp"
16 
17 #include <spirv/unified1/spirv.hpp>
18 
19 namespace sw {
20 
21 struct SpirvShader::Impl::Group
22 {
23 	// Template function to perform a binary operation.
24 	// |TYPE| should be the type of the binary operation (as a SIMD::<ScalarType>).
25 	// |I| should be a type suitable to initialize the identity value.
26 	// |APPLY| should be a callable object that takes two RValue<TYPE> parameters
27 	// and returns a new RValue<TYPE> corresponding to the operation's result.
28 	template<typename TYPE, typename I, typename APPLY>
BinaryOperationsw::SpirvShader::Impl::Group29 	static void BinaryOperation(
30 	    const SpirvShader *shader,
31 	    const SpirvShader::InsnIterator &insn,
32 	    const SpirvShader::EmitState *state,
33 	    Intermediate &dst,
34 	    const I identityValue,
35 	    APPLY &&apply)
36 	{
37 		SpirvShader::GenericValue value(shader, state, insn.word(5));
38 		auto &type = shader->getType(SpirvShader::Type::ID(insn.word(1)));
39 		for(auto i = 0u; i < type.sizeInComponents; i++)
40 		{
41 			auto mask = As<SIMD::UInt>(state->activeLaneMask());
42 			auto identity = TYPE(identityValue);
43 			SIMD::UInt v_uint = (value.UInt(i) & mask) | (As<SIMD::UInt>(identity) & ~mask);
44 			TYPE v = As<TYPE>(v_uint);
45 			switch(spv::GroupOperation(insn.word(4)))
46 			{
47 				case spv::GroupOperationReduce:
48 				{
49 					// NOTE: floating-point add and multiply are not really commutative so
50 					//       ensure that all values in the final lanes are identical
51 					TYPE v2 = apply(v.xxzz, v.yyww);    // [xy]   [xy]   [zw]   [zw]
52 					TYPE v3 = apply(v2.xxxx, v2.zzzz);  // [xyzw] [xyzw] [xyzw] [xyzw]
53 					dst.move(i, v3);
54 					break;
55 				}
56 				case spv::GroupOperationInclusiveScan:
57 				{
58 					TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */);      // [x] [xy] [yz]  [zw]
59 					TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id,  id, v2.x, v2.y] */);  // [x] [xy] [xyz] [xyzw]
60 					dst.move(i, v3);
61 					break;
62 				}
63 				case spv::GroupOperationExclusiveScan:
64 				{
65 					TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */);      // [x] [xy] [yz]  [zw]
66 					TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id,  id, v2.x, v2.y] */);  // [x] [xy] [xyz] [xyzw]
67 					auto v4 = Shuffle(v3, identity, 0x4012 /* [id, v3.x, v3.y, v3.z] */);            // [i] [x]  [xy]  [xyz]
68 					dst.move(i, v4);
69 					break;
70 				}
71 				default:
72 					UNSUPPORTED("EmitGroupNonUniform op: %s Group operation: %d",
73 					            SpirvShader::OpcodeName(type.opcode()).c_str(), insn.word(4));
74 			}
75 		}
76 	}
77 };
78 
EmitGroupNonUniform(InsnIterator insn,EmitState * state) const79 SpirvShader::EmitResult SpirvShader::EmitGroupNonUniform(InsnIterator insn, EmitState *state) const
80 {
81 	static_assert(SIMD::Width == 4, "EmitGroupNonUniform makes many assumptions that the SIMD vector width is 4");
82 
83 	auto &type = getType(Type::ID(insn.word(1)));
84 	Object::ID resultId = insn.word(2);
85 	auto scope = spv::Scope(GetConstScalarInt(insn.word(3)));
86 	ASSERT_MSG(scope == spv::ScopeSubgroup, "Scope for Non Uniform Group Operations must be Subgroup for Vulkan 1.1");
87 
88 	auto &dst = state->createIntermediate(resultId, type.sizeInComponents);
89 
90 	switch(insn.opcode())
91 	{
92 		case spv::OpGroupNonUniformElect:
93 		{
94 			// Result is true only in the active invocation with the lowest id
95 			// in the group, otherwise result is false.
96 			SIMD::Int active = state->activeLaneMask();
97 			// TODO: Would be nice if we could write this as:
98 			//   elect = active & ~(active.Oxyz | active.OOxy | active.OOOx)
99 			auto v0111 = SIMD::Int(0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF);
100 			auto elect = active & ~(v0111 & (active.xxyz | active.xxxy | active.xxxx));
101 			dst.move(0, elect);
102 			break;
103 		}
104 
105 		case spv::OpGroupNonUniformAll:
106 		{
107 			GenericValue predicate(this, state, insn.word(4));
108 			dst.move(0, AndAll(predicate.UInt(0) | ~As<SIMD::UInt>(state->activeLaneMask())));
109 			break;
110 		}
111 
112 		case spv::OpGroupNonUniformAny:
113 		{
114 			GenericValue predicate(this, state, insn.word(4));
115 			dst.move(0, OrAll(predicate.UInt(0) & As<SIMD::UInt>(state->activeLaneMask())));
116 			break;
117 		}
118 
119 		case spv::OpGroupNonUniformAllEqual:
120 		{
121 			GenericValue value(this, state, insn.word(4));
122 			auto res = SIMD::UInt(0xffffffff);
123 			SIMD::UInt active = As<SIMD::UInt>(state->activeLaneMask());
124 			SIMD::UInt inactive = ~active;
125 			for(auto i = 0u; i < type.sizeInComponents; i++)
126 			{
127 				SIMD::UInt v = value.UInt(i) & active;
128 				SIMD::UInt filled = v;
129 				for(int j = 0; j < SIMD::Width - 1; j++)
130 				{
131 					filled |= filled.yzwx & inactive;  // Populate inactive 'holes' with a live value
132 				}
133 				res &= AndAll(CmpEQ(filled.xyzw, filled.yzwx));
134 			}
135 			dst.move(0, res);
136 			break;
137 		}
138 
139 		case spv::OpGroupNonUniformBroadcast:
140 		{
141 			auto valueId = Object::ID(insn.word(4));
142 			auto id = SIMD::Int(GetConstScalarInt(insn.word(5)));
143 			GenericValue value(this, state, valueId);
144 			auto mask = CmpEQ(id, SIMD::Int(0, 1, 2, 3));
145 			for(auto i = 0u; i < type.sizeInComponents; i++)
146 			{
147 				dst.move(i, OrAll(value.Int(i) & mask));
148 			}
149 			break;
150 		}
151 
152 		case spv::OpGroupNonUniformBroadcastFirst:
153 		{
154 			auto valueId = Object::ID(insn.word(4));
155 			GenericValue value(this, state, valueId);
156 			// Result is true only in the active invocation with the lowest id
157 			// in the group, otherwise result is false.
158 			SIMD::Int active = state->activeLaneMask();
159 			// TODO: Would be nice if we could write this as:
160 			//   elect = active & ~(active.Oxyz | active.OOxy | active.OOOx)
161 			auto v0111 = SIMD::Int(0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF);
162 			auto elect = active & ~(v0111 & (active.xxyz | active.xxxy | active.xxxx));
163 			for(auto i = 0u; i < type.sizeInComponents; i++)
164 			{
165 				dst.move(i, OrAll(value.Int(i) & elect));
166 			}
167 			break;
168 		}
169 
170 		case spv::OpGroupNonUniformBallot:
171 		{
172 			ASSERT(type.sizeInComponents == 4);
173 			GenericValue predicate(this, state, insn.word(4));
174 			dst.move(0, SIMD::Int(SignMask(state->activeLaneMask() & predicate.Int(0))));
175 			dst.move(1, SIMD::Int(0));
176 			dst.move(2, SIMD::Int(0));
177 			dst.move(3, SIMD::Int(0));
178 			break;
179 		}
180 
181 		case spv::OpGroupNonUniformInverseBallot:
182 		{
183 			auto valueId = Object::ID(insn.word(4));
184 			ASSERT(type.sizeInComponents == 1);
185 			ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
186 			GenericValue value(this, state, valueId);
187 			auto bit = (value.Int(0) >> SIMD::Int(0, 1, 2, 3)) & SIMD::Int(1);
188 			dst.move(0, -bit);
189 			break;
190 		}
191 
192 		case spv::OpGroupNonUniformBallotBitExtract:
193 		{
194 			auto valueId = Object::ID(insn.word(4));
195 			auto indexId = Object::ID(insn.word(5));
196 			ASSERT(type.sizeInComponents == 1);
197 			ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
198 			ASSERT(getType(getObject(indexId).type).sizeInComponents == 1);
199 			GenericValue value(this, state, valueId);
200 			GenericValue index(this, state, indexId);
201 			auto vecIdx = index.Int(0) / SIMD::Int(32);
202 			auto bitIdx = index.Int(0) & SIMD::Int(31);
203 			auto bits = (value.Int(0) & CmpEQ(vecIdx, SIMD::Int(0))) |
204 			            (value.Int(1) & CmpEQ(vecIdx, SIMD::Int(1))) |
205 			            (value.Int(2) & CmpEQ(vecIdx, SIMD::Int(2))) |
206 			            (value.Int(3) & CmpEQ(vecIdx, SIMD::Int(3)));
207 			dst.move(0, -((bits >> bitIdx) & SIMD::Int(1)));
208 			break;
209 		}
210 
211 		case spv::OpGroupNonUniformBallotBitCount:
212 		{
213 			auto operation = spv::GroupOperation(insn.word(4));
214 			auto valueId = Object::ID(insn.word(5));
215 			ASSERT(type.sizeInComponents == 1);
216 			ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
217 			GenericValue value(this, state, valueId);
218 			switch(operation)
219 			{
220 				case spv::GroupOperationReduce:
221 					dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(15)));
222 					break;
223 				case spv::GroupOperationInclusiveScan:
224 					dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(1, 3, 7, 15)));
225 					break;
226 				case spv::GroupOperationExclusiveScan:
227 					dst.move(0, CountBits(value.UInt(0) & SIMD::UInt(0, 1, 3, 7)));
228 					break;
229 				default:
230 					UNSUPPORTED("GroupOperation %d", int(operation));
231 			}
232 			break;
233 		}
234 
235 		case spv::OpGroupNonUniformBallotFindLSB:
236 		{
237 			auto valueId = Object::ID(insn.word(4));
238 			ASSERT(type.sizeInComponents == 1);
239 			ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
240 			GenericValue value(this, state, valueId);
241 			dst.move(0, Cttz(value.UInt(0) & SIMD::UInt(15), true));
242 			break;
243 		}
244 
245 		case spv::OpGroupNonUniformBallotFindMSB:
246 		{
247 			auto valueId = Object::ID(insn.word(4));
248 			ASSERT(type.sizeInComponents == 1);
249 			ASSERT(getType(getObject(valueId).type).sizeInComponents == 4);
250 			GenericValue value(this, state, valueId);
251 			dst.move(0, SIMD::UInt(31) - Ctlz(value.UInt(0) & SIMD::UInt(15), false));
252 			break;
253 		}
254 
255 		case spv::OpGroupNonUniformShuffle:
256 		{
257 			GenericValue value(this, state, insn.word(4));
258 			GenericValue id(this, state, insn.word(5));
259 			auto x = CmpEQ(SIMD::Int(0), id.Int(0));
260 			auto y = CmpEQ(SIMD::Int(1), id.Int(0));
261 			auto z = CmpEQ(SIMD::Int(2), id.Int(0));
262 			auto w = CmpEQ(SIMD::Int(3), id.Int(0));
263 			for(auto i = 0u; i < type.sizeInComponents; i++)
264 			{
265 				SIMD::Int v = value.Int(i);
266 				dst.move(i, (x & v.xxxx) | (y & v.yyyy) | (z & v.zzzz) | (w & v.wwww));
267 			}
268 			break;
269 		}
270 
271 		case spv::OpGroupNonUniformShuffleXor:
272 		{
273 			GenericValue value(this, state, insn.word(4));
274 			GenericValue mask(this, state, insn.word(5));
275 			auto x = CmpEQ(SIMD::Int(0), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
276 			auto y = CmpEQ(SIMD::Int(1), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
277 			auto z = CmpEQ(SIMD::Int(2), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
278 			auto w = CmpEQ(SIMD::Int(3), SIMD::Int(0, 1, 2, 3) ^ mask.Int(0));
279 			for(auto i = 0u; i < type.sizeInComponents; i++)
280 			{
281 				SIMD::Int v = value.Int(i);
282 				dst.move(i, (x & v.xxxx) | (y & v.yyyy) | (z & v.zzzz) | (w & v.wwww));
283 			}
284 			break;
285 		}
286 
287 		case spv::OpGroupNonUniformShuffleUp:
288 		{
289 			GenericValue value(this, state, insn.word(4));
290 			GenericValue delta(this, state, insn.word(5));
291 			auto d0 = CmpEQ(SIMD::Int(0), delta.Int(0));
292 			auto d1 = CmpEQ(SIMD::Int(1), delta.Int(0));
293 			auto d2 = CmpEQ(SIMD::Int(2), delta.Int(0));
294 			auto d3 = CmpEQ(SIMD::Int(3), delta.Int(0));
295 			for(auto i = 0u; i < type.sizeInComponents; i++)
296 			{
297 				SIMD::Int v = value.Int(i);
298 				dst.move(i, (d0 & v.xyzw) | (d1 & v.xxyz) | (d2 & v.xxxy) | (d3 & v.xxxx));
299 			}
300 			break;
301 		}
302 
303 		case spv::OpGroupNonUniformShuffleDown:
304 		{
305 			GenericValue value(this, state, insn.word(4));
306 			GenericValue delta(this, state, insn.word(5));
307 			auto d0 = CmpEQ(SIMD::Int(0), delta.Int(0));
308 			auto d1 = CmpEQ(SIMD::Int(1), delta.Int(0));
309 			auto d2 = CmpEQ(SIMD::Int(2), delta.Int(0));
310 			auto d3 = CmpEQ(SIMD::Int(3), delta.Int(0));
311 			for(auto i = 0u; i < type.sizeInComponents; i++)
312 			{
313 				SIMD::Int v = value.Int(i);
314 				dst.move(i, (d0 & v.xyzw) | (d1 & v.yzww) | (d2 & v.zwww) | (d3 & v.wwww));
315 			}
316 			break;
317 		}
318 
319 		case spv::OpGroupNonUniformIAdd:
320 			Impl::Group::BinaryOperation<SIMD::Int>(
321 			    this, insn, state, dst, 0,
322 			    [](auto a, auto b) { return a + b; });
323 			break;
324 
325 		case spv::OpGroupNonUniformFAdd:
326 			Impl::Group::BinaryOperation<SIMD::Float>(
327 			    this, insn, state, dst, 0.0f,
328 			    [](auto a, auto b) { return a + b; });
329 			break;
330 
331 		case spv::OpGroupNonUniformIMul:
332 			Impl::Group::BinaryOperation<SIMD::Int>(
333 			    this, insn, state, dst, 1,
334 			    [](auto a, auto b) { return a * b; });
335 			break;
336 
337 		case spv::OpGroupNonUniformFMul:
338 			Impl::Group::BinaryOperation<SIMD::Float>(
339 			    this, insn, state, dst, 1.0f,
340 			    [](auto a, auto b) { return a * b; });
341 			break;
342 
343 		case spv::OpGroupNonUniformBitwiseAnd:
344 			Impl::Group::BinaryOperation<SIMD::UInt>(
345 			    this, insn, state, dst, ~0u,
346 			    [](auto a, auto b) { return a & b; });
347 			break;
348 
349 		case spv::OpGroupNonUniformBitwiseOr:
350 			Impl::Group::BinaryOperation<SIMD::UInt>(
351 			    this, insn, state, dst, 0,
352 			    [](auto a, auto b) { return a | b; });
353 			break;
354 
355 		case spv::OpGroupNonUniformBitwiseXor:
356 			Impl::Group::BinaryOperation<SIMD::UInt>(
357 			    this, insn, state, dst, 0,
358 			    [](auto a, auto b) { return a ^ b; });
359 			break;
360 
361 		case spv::OpGroupNonUniformSMin:
362 			Impl::Group::BinaryOperation<SIMD::Int>(
363 			    this, insn, state, dst, INT32_MAX,
364 			    [](auto a, auto b) { return Min(a, b); });
365 			break;
366 
367 		case spv::OpGroupNonUniformUMin:
368 			Impl::Group::BinaryOperation<SIMD::UInt>(
369 			    this, insn, state, dst, ~0u,
370 			    [](auto a, auto b) { return Min(a, b); });
371 			break;
372 
373 		case spv::OpGroupNonUniformFMin:
374 			Impl::Group::BinaryOperation<SIMD::Float>(
375 			    this, insn, state, dst, SIMD::Float::infinity(),
376 			    [](auto a, auto b) { return NMin(a, b); });
377 			break;
378 
379 		case spv::OpGroupNonUniformSMax:
380 			Impl::Group::BinaryOperation<SIMD::Int>(
381 			    this, insn, state, dst, INT32_MIN,
382 			    [](auto a, auto b) { return Max(a, b); });
383 			break;
384 
385 		case spv::OpGroupNonUniformUMax:
386 			Impl::Group::BinaryOperation<SIMD::UInt>(
387 			    this, insn, state, dst, 0,
388 			    [](auto a, auto b) { return Max(a, b); });
389 			break;
390 
391 		case spv::OpGroupNonUniformFMax:
392 			Impl::Group::BinaryOperation<SIMD::Float>(
393 			    this, insn, state, dst, -SIMD::Float::infinity(),
394 			    [](auto a, auto b) { return NMax(a, b); });
395 			break;
396 
397 		case spv::OpGroupNonUniformLogicalAnd:
398 			Impl::Group::BinaryOperation<SIMD::UInt>(
399 			    this, insn, state, dst, ~0u,
400 			    [](auto a, auto b) {
401 				    SIMD::UInt zero = SIMD::UInt(0);
402 				    return CmpNEQ(a, zero) & CmpNEQ(b, zero);
403 			    });
404 			break;
405 
406 		case spv::OpGroupNonUniformLogicalOr:
407 			Impl::Group::BinaryOperation<SIMD::UInt>(
408 			    this, insn, state, dst, 0,
409 			    [](auto a, auto b) {
410 				    SIMD::UInt zero = SIMD::UInt(0);
411 				    return CmpNEQ(a, zero) | CmpNEQ(b, zero);
412 			    });
413 			break;
414 
415 		case spv::OpGroupNonUniformLogicalXor:
416 			Impl::Group::BinaryOperation<SIMD::UInt>(
417 			    this, insn, state, dst, 0,
418 			    [](auto a, auto b) {
419 				    SIMD::UInt zero = SIMD::UInt(0);
420 				    return CmpNEQ(a, zero) ^ CmpNEQ(b, zero);
421 			    });
422 			break;
423 
424 		default:
425 			UNSUPPORTED("EmitGroupNonUniform op: %s", OpcodeName(type.opcode()).c_str());
426 	}
427 	return EmitResult::Continue;
428 }
429 
430 }  // namespace sw
431