1 /*
2 Copyright (c) 2010-2021, Intel Corporation
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are
7 met:
8
9 * Redistributions of source code must retain the above copyright
10 notice, this list of conditions and the following disclaimer.
11
12 * Redistributions in binary form must reproduce the above copyright
13 notice, this list of conditions and the following disclaimer in the
14 documentation and/or other materials provided with the distribution.
15
16 * Neither the name of Intel Corporation nor the names of its
17 contributors may be used to endorse or promote products derived from
18 this software without specific prior written permission.
19
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22 IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23 TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24 PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 /** @file llvmutil.cpp
35 @brief Implementations of various LLVM utility types and classes.
36 */
37
38 #include "llvmutil.h"
39 #include "ispc.h"
40 #include "type.h"
41
42 #include <map>
43 #include <set>
44 #include <vector>
45
46 #include <llvm/Analysis/ValueTracking.h>
47 #include <llvm/IR/BasicBlock.h>
48 #include <llvm/IR/Instructions.h>
49 #include <llvm/IR/Module.h>
50
51 #ifdef ISPC_GENX_ENABLED
52 #include <llvm/GenXIntrinsics/GenXIntrinsics.h>
53 #endif
54
55 namespace ispc {
56
57 llvm::Type *LLVMTypes::VoidType = NULL;
58 llvm::PointerType *LLVMTypes::VoidPointerType = NULL;
59 llvm::Type *LLVMTypes::PointerIntType = NULL;
60 llvm::Type *LLVMTypes::BoolType = NULL;
61 llvm::Type *LLVMTypes::BoolStorageType = NULL;
62
63 llvm::Type *LLVMTypes::Int8Type = NULL;
64 llvm::Type *LLVMTypes::Int16Type = NULL;
65 llvm::Type *LLVMTypes::Int32Type = NULL;
66 llvm::Type *LLVMTypes::Int64Type = NULL;
67 llvm::Type *LLVMTypes::FloatType = NULL;
68 llvm::Type *LLVMTypes::DoubleType = NULL;
69
70 llvm::Type *LLVMTypes::Int8PointerType = NULL;
71 llvm::Type *LLVMTypes::Int16PointerType = NULL;
72 llvm::Type *LLVMTypes::Int32PointerType = NULL;
73 llvm::Type *LLVMTypes::Int64PointerType = NULL;
74 llvm::Type *LLVMTypes::FloatPointerType = NULL;
75 llvm::Type *LLVMTypes::DoublePointerType = NULL;
76
77 llvm::VectorType *LLVMTypes::MaskType = NULL;
78 llvm::VectorType *LLVMTypes::BoolVectorType = NULL;
79 llvm::VectorType *LLVMTypes::BoolVectorStorageType = NULL;
80
81 llvm::VectorType *LLVMTypes::Int1VectorType = NULL;
82 llvm::VectorType *LLVMTypes::Int8VectorType = NULL;
83 llvm::VectorType *LLVMTypes::Int16VectorType = NULL;
84 llvm::VectorType *LLVMTypes::Int32VectorType = NULL;
85 llvm::VectorType *LLVMTypes::Int64VectorType = NULL;
86 llvm::VectorType *LLVMTypes::FloatVectorType = NULL;
87 llvm::VectorType *LLVMTypes::DoubleVectorType = NULL;
88
89 llvm::Type *LLVMTypes::Int8VectorPointerType = NULL;
90 llvm::Type *LLVMTypes::Int16VectorPointerType = NULL;
91 llvm::Type *LLVMTypes::Int32VectorPointerType = NULL;
92 llvm::Type *LLVMTypes::Int64VectorPointerType = NULL;
93 llvm::Type *LLVMTypes::FloatVectorPointerType = NULL;
94 llvm::Type *LLVMTypes::DoubleVectorPointerType = NULL;
95
96 llvm::VectorType *LLVMTypes::VoidPointerVectorType = NULL;
97
98 llvm::Constant *LLVMTrue = NULL;
99 llvm::Constant *LLVMFalse = NULL;
100 llvm::Constant *LLVMTrueInStorage = NULL;
101 llvm::Constant *LLVMFalseInStorage = NULL;
102 llvm::Constant *LLVMMaskAllOn = NULL;
103 llvm::Constant *LLVMMaskAllOff = NULL;
104
InitLLVMUtil(llvm::LLVMContext * ctx,Target & target)105 void InitLLVMUtil(llvm::LLVMContext *ctx, Target &target) {
106 LLVMTypes::VoidType = llvm::Type::getVoidTy(*ctx);
107 LLVMTypes::VoidPointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*ctx), 0);
108 LLVMTypes::PointerIntType = target.is32Bit() ? llvm::Type::getInt32Ty(*ctx) : llvm::Type::getInt64Ty(*ctx);
109
110 LLVMTypes::BoolType = llvm::Type::getInt1Ty(*ctx);
111 LLVMTypes::Int8Type = LLVMTypes::BoolStorageType = llvm::Type::getInt8Ty(*ctx);
112 LLVMTypes::Int16Type = llvm::Type::getInt16Ty(*ctx);
113 LLVMTypes::Int32Type = llvm::Type::getInt32Ty(*ctx);
114 LLVMTypes::Int64Type = llvm::Type::getInt64Ty(*ctx);
115 LLVMTypes::FloatType = llvm::Type::getFloatTy(*ctx);
116 LLVMTypes::DoubleType = llvm::Type::getDoubleTy(*ctx);
117
118 LLVMTypes::Int8PointerType = llvm::PointerType::get(LLVMTypes::Int8Type, 0);
119 LLVMTypes::Int16PointerType = llvm::PointerType::get(LLVMTypes::Int16Type, 0);
120 LLVMTypes::Int32PointerType = llvm::PointerType::get(LLVMTypes::Int32Type, 0);
121 LLVMTypes::Int64PointerType = llvm::PointerType::get(LLVMTypes::Int64Type, 0);
122 LLVMTypes::FloatPointerType = llvm::PointerType::get(LLVMTypes::FloatType, 0);
123 LLVMTypes::DoublePointerType = llvm::PointerType::get(LLVMTypes::DoubleType, 0);
124
125 switch (target.getMaskBitCount()) {
126 case 1:
127 LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
128 LLVMVECTOR::get(llvm::Type::getInt1Ty(*ctx), target.getVectorWidth());
129 break;
130 case 8:
131 LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
132 LLVMVECTOR::get(llvm::Type::getInt8Ty(*ctx), target.getVectorWidth());
133 break;
134 case 16:
135 LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
136 LLVMVECTOR::get(llvm::Type::getInt16Ty(*ctx), target.getVectorWidth());
137 break;
138 case 32:
139 LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
140 LLVMVECTOR::get(llvm::Type::getInt32Ty(*ctx), target.getVectorWidth());
141 break;
142 case 64:
143 LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
144 LLVMVECTOR::get(llvm::Type::getInt64Ty(*ctx), target.getVectorWidth());
145 break;
146 default:
147 FATAL("Unhandled mask width for initializing MaskType");
148 }
149
150 LLVMTypes::Int1VectorType = LLVMVECTOR::get(llvm::Type::getInt1Ty(*ctx), target.getVectorWidth());
151 LLVMTypes::Int8VectorType = LLVMTypes::BoolVectorStorageType =
152 LLVMVECTOR::get(LLVMTypes::Int8Type, target.getVectorWidth());
153 LLVMTypes::Int16VectorType = LLVMVECTOR::get(LLVMTypes::Int16Type, target.getVectorWidth());
154 LLVMTypes::Int32VectorType = LLVMVECTOR::get(LLVMTypes::Int32Type, target.getVectorWidth());
155 LLVMTypes::Int64VectorType = LLVMVECTOR::get(LLVMTypes::Int64Type, target.getVectorWidth());
156 LLVMTypes::FloatVectorType = LLVMVECTOR::get(LLVMTypes::FloatType, target.getVectorWidth());
157 LLVMTypes::DoubleVectorType = LLVMVECTOR::get(LLVMTypes::DoubleType, target.getVectorWidth());
158
159 LLVMTypes::Int8VectorPointerType = llvm::PointerType::get(LLVMTypes::Int8VectorType, 0);
160 LLVMTypes::Int16VectorPointerType = llvm::PointerType::get(LLVMTypes::Int16VectorType, 0);
161 LLVMTypes::Int32VectorPointerType = llvm::PointerType::get(LLVMTypes::Int32VectorType, 0);
162 LLVMTypes::Int64VectorPointerType = llvm::PointerType::get(LLVMTypes::Int64VectorType, 0);
163 LLVMTypes::FloatVectorPointerType = llvm::PointerType::get(LLVMTypes::FloatVectorType, 0);
164 LLVMTypes::DoubleVectorPointerType = llvm::PointerType::get(LLVMTypes::DoubleVectorType, 0);
165
166 LLVMTypes::VoidPointerVectorType = g->target->is32Bit() ? LLVMTypes::Int32VectorType : LLVMTypes::Int64VectorType;
167
168 LLVMTrue = llvm::ConstantInt::getTrue(*ctx);
169 LLVMFalse = llvm::ConstantInt::getFalse(*ctx);
170 LLVMTrueInStorage = llvm::ConstantInt::get(LLVMTypes::Int8Type, 0xff, false /*unsigned*/);
171 LLVMFalseInStorage = llvm::ConstantInt::get(LLVMTypes::Int8Type, 0x00, false /*unsigned*/);
172
173 std::vector<llvm::Constant *> maskOnes;
174 llvm::Constant *onMask = NULL;
175 switch (target.getMaskBitCount()) {
176 case 1:
177 onMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 1, false /*unsigned*/); // 0x1
178 break;
179 case 8:
180 onMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), -1, true /*signed*/); // 0xff
181 break;
182 case 16:
183 onMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), -1, true /*signed*/); // 0xffff
184 break;
185 case 32:
186 onMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), -1, true /*signed*/); // 0xffffffff
187 break;
188 case 64:
189 onMask = llvm::ConstantInt::get(llvm::Type::getInt64Ty(*ctx), -1, true /*signed*/); // 0xffffffffffffffffull
190 break;
191 default:
192 FATAL("Unhandled mask width for onMask");
193 }
194
195 for (int i = 0; i < target.getVectorWidth(); ++i)
196 maskOnes.push_back(onMask);
197 LLVMMaskAllOn = llvm::ConstantVector::get(maskOnes);
198
199 std::vector<llvm::Constant *> maskZeros;
200 llvm::Constant *offMask = NULL;
201 switch (target.getMaskBitCount()) {
202 case 1:
203 offMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 0, true /*signed*/);
204 break;
205 case 8:
206 offMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), 0, true /*signed*/);
207 break;
208 case 16:
209 offMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), 0, true /*signed*/);
210 break;
211 case 32:
212 offMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0, true /*signed*/);
213 break;
214 case 64:
215 offMask = llvm::ConstantInt::get(llvm::Type::getInt64Ty(*ctx), 0, true /*signed*/);
216 break;
217 default:
218 FATAL("Unhandled mask width for offMask");
219 }
220 for (int i = 0; i < target.getVectorWidth(); ++i)
221 maskZeros.push_back(offMask);
222 LLVMMaskAllOff = llvm::ConstantVector::get(maskZeros);
223 }
224
LLVMInt8(int8_t ival)225 llvm::ConstantInt *LLVMInt8(int8_t ival) {
226 return llvm::ConstantInt::get(llvm::Type::getInt8Ty(*g->ctx), ival, true /*signed*/);
227 }
228
LLVMUInt8(uint8_t ival)229 llvm::ConstantInt *LLVMUInt8(uint8_t ival) {
230 return llvm::ConstantInt::get(llvm::Type::getInt8Ty(*g->ctx), ival, false /*unsigned*/);
231 }
232
LLVMInt16(int16_t ival)233 llvm::ConstantInt *LLVMInt16(int16_t ival) {
234 return llvm::ConstantInt::get(llvm::Type::getInt16Ty(*g->ctx), ival, true /*signed*/);
235 }
236
LLVMUInt16(uint16_t ival)237 llvm::ConstantInt *LLVMUInt16(uint16_t ival) {
238 return llvm::ConstantInt::get(llvm::Type::getInt16Ty(*g->ctx), ival, false /*unsigned*/);
239 }
240
LLVMInt32(int32_t ival)241 llvm::ConstantInt *LLVMInt32(int32_t ival) {
242 return llvm::ConstantInt::get(llvm::Type::getInt32Ty(*g->ctx), ival, true /*signed*/);
243 }
244
LLVMUInt32(uint32_t ival)245 llvm::ConstantInt *LLVMUInt32(uint32_t ival) {
246 return llvm::ConstantInt::get(llvm::Type::getInt32Ty(*g->ctx), ival, false /*unsigned*/);
247 }
248
LLVMInt64(int64_t ival)249 llvm::ConstantInt *LLVMInt64(int64_t ival) {
250 return llvm::ConstantInt::get(llvm::Type::getInt64Ty(*g->ctx), ival, true /*signed*/);
251 }
252
LLVMUInt64(uint64_t ival)253 llvm::ConstantInt *LLVMUInt64(uint64_t ival) {
254 return llvm::ConstantInt::get(llvm::Type::getInt64Ty(*g->ctx), ival, false /*unsigned*/);
255 }
256
LLVMFloat(float fval)257 llvm::Constant *LLVMFloat(float fval) { return llvm::ConstantFP::get(llvm::Type::getFloatTy(*g->ctx), fval); }
258
LLVMDouble(double dval)259 llvm::Constant *LLVMDouble(double dval) { return llvm::ConstantFP::get(llvm::Type::getDoubleTy(*g->ctx), dval); }
260
LLVMInt8Vector(int8_t ival)261 llvm::Constant *LLVMInt8Vector(int8_t ival) {
262 llvm::Constant *v = LLVMInt8(ival);
263 std::vector<llvm::Constant *> vals;
264 for (int i = 0; i < g->target->getVectorWidth(); ++i)
265 vals.push_back(v);
266 return llvm::ConstantVector::get(vals);
267 }
268
LLVMInt8Vector(const int8_t * ivec)269 llvm::Constant *LLVMInt8Vector(const int8_t *ivec) {
270 std::vector<llvm::Constant *> vals;
271 for (int i = 0; i < g->target->getVectorWidth(); ++i)
272 vals.push_back(LLVMInt8(ivec[i]));
273 return llvm::ConstantVector::get(vals);
274 }
275
LLVMUInt8Vector(uint8_t ival)276 llvm::Constant *LLVMUInt8Vector(uint8_t ival) {
277 llvm::Constant *v = LLVMUInt8(ival);
278 std::vector<llvm::Constant *> vals;
279 for (int i = 0; i < g->target->getVectorWidth(); ++i)
280 vals.push_back(v);
281 return llvm::ConstantVector::get(vals);
282 }
283
LLVMUInt8Vector(const uint8_t * ivec)284 llvm::Constant *LLVMUInt8Vector(const uint8_t *ivec) {
285 std::vector<llvm::Constant *> vals;
286 for (int i = 0; i < g->target->getVectorWidth(); ++i)
287 vals.push_back(LLVMUInt8(ivec[i]));
288 return llvm::ConstantVector::get(vals);
289 }
290
LLVMInt16Vector(int16_t ival)291 llvm::Constant *LLVMInt16Vector(int16_t ival) {
292 llvm::Constant *v = LLVMInt16(ival);
293 std::vector<llvm::Constant *> vals;
294 for (int i = 0; i < g->target->getVectorWidth(); ++i)
295 vals.push_back(v);
296 return llvm::ConstantVector::get(vals);
297 }
298
LLVMInt16Vector(const int16_t * ivec)299 llvm::Constant *LLVMInt16Vector(const int16_t *ivec) {
300 std::vector<llvm::Constant *> vals;
301 for (int i = 0; i < g->target->getVectorWidth(); ++i)
302 vals.push_back(LLVMInt16(ivec[i]));
303 return llvm::ConstantVector::get(vals);
304 }
305
LLVMUInt16Vector(uint16_t ival)306 llvm::Constant *LLVMUInt16Vector(uint16_t ival) {
307 llvm::Constant *v = LLVMUInt16(ival);
308 std::vector<llvm::Constant *> vals;
309 for (int i = 0; i < g->target->getVectorWidth(); ++i)
310 vals.push_back(v);
311 return llvm::ConstantVector::get(vals);
312 }
313
LLVMUInt16Vector(const uint16_t * ivec)314 llvm::Constant *LLVMUInt16Vector(const uint16_t *ivec) {
315 std::vector<llvm::Constant *> vals;
316 for (int i = 0; i < g->target->getVectorWidth(); ++i)
317 vals.push_back(LLVMUInt16(ivec[i]));
318 return llvm::ConstantVector::get(vals);
319 }
320
LLVMInt32Vector(int32_t ival)321 llvm::Constant *LLVMInt32Vector(int32_t ival) {
322 llvm::Constant *v = LLVMInt32(ival);
323 std::vector<llvm::Constant *> vals;
324 for (int i = 0; i < g->target->getVectorWidth(); ++i)
325 vals.push_back(v);
326 return llvm::ConstantVector::get(vals);
327 }
328
LLVMInt32Vector(const int32_t * ivec)329 llvm::Constant *LLVMInt32Vector(const int32_t *ivec) {
330 std::vector<llvm::Constant *> vals;
331 for (int i = 0; i < g->target->getVectorWidth(); ++i)
332 vals.push_back(LLVMInt32(ivec[i]));
333 return llvm::ConstantVector::get(vals);
334 }
335
LLVMUInt32Vector(uint32_t ival)336 llvm::Constant *LLVMUInt32Vector(uint32_t ival) {
337 llvm::Constant *v = LLVMUInt32(ival);
338 std::vector<llvm::Constant *> vals;
339 for (int i = 0; i < g->target->getVectorWidth(); ++i)
340 vals.push_back(v);
341 return llvm::ConstantVector::get(vals);
342 }
343
LLVMUInt32Vector(const uint32_t * ivec)344 llvm::Constant *LLVMUInt32Vector(const uint32_t *ivec) {
345 std::vector<llvm::Constant *> vals;
346 for (int i = 0; i < g->target->getVectorWidth(); ++i)
347 vals.push_back(LLVMUInt32(ivec[i]));
348 return llvm::ConstantVector::get(vals);
349 }
350
LLVMFloatVector(float fval)351 llvm::Constant *LLVMFloatVector(float fval) {
352 llvm::Constant *v = LLVMFloat(fval);
353 std::vector<llvm::Constant *> vals;
354 for (int i = 0; i < g->target->getVectorWidth(); ++i)
355 vals.push_back(v);
356 return llvm::ConstantVector::get(vals);
357 }
358
LLVMFloatVector(const float * fvec)359 llvm::Constant *LLVMFloatVector(const float *fvec) {
360 std::vector<llvm::Constant *> vals;
361 for (int i = 0; i < g->target->getVectorWidth(); ++i)
362 vals.push_back(LLVMFloat(fvec[i]));
363 return llvm::ConstantVector::get(vals);
364 }
365
LLVMDoubleVector(double dval)366 llvm::Constant *LLVMDoubleVector(double dval) {
367 llvm::Constant *v = LLVMDouble(dval);
368 std::vector<llvm::Constant *> vals;
369 for (int i = 0; i < g->target->getVectorWidth(); ++i)
370 vals.push_back(v);
371 return llvm::ConstantVector::get(vals);
372 }
373
LLVMDoubleVector(const double * dvec)374 llvm::Constant *LLVMDoubleVector(const double *dvec) {
375 std::vector<llvm::Constant *> vals;
376 for (int i = 0; i < g->target->getVectorWidth(); ++i)
377 vals.push_back(LLVMDouble(dvec[i]));
378 return llvm::ConstantVector::get(vals);
379 }
380
LLVMInt64Vector(int64_t ival)381 llvm::Constant *LLVMInt64Vector(int64_t ival) {
382 llvm::Constant *v = LLVMInt64(ival);
383 std::vector<llvm::Constant *> vals;
384 for (int i = 0; i < g->target->getVectorWidth(); ++i)
385 vals.push_back(v);
386 return llvm::ConstantVector::get(vals);
387 }
388
LLVMInt64Vector(const int64_t * ivec)389 llvm::Constant *LLVMInt64Vector(const int64_t *ivec) {
390 std::vector<llvm::Constant *> vals;
391 for (int i = 0; i < g->target->getVectorWidth(); ++i)
392 vals.push_back(LLVMInt64(ivec[i]));
393 return llvm::ConstantVector::get(vals);
394 }
395
LLVMUInt64Vector(uint64_t ival)396 llvm::Constant *LLVMUInt64Vector(uint64_t ival) {
397 llvm::Constant *v = LLVMUInt64(ival);
398 std::vector<llvm::Constant *> vals;
399 for (int i = 0; i < g->target->getVectorWidth(); ++i)
400 vals.push_back(v);
401 return llvm::ConstantVector::get(vals);
402 }
403
LLVMUInt64Vector(const uint64_t * ivec)404 llvm::Constant *LLVMUInt64Vector(const uint64_t *ivec) {
405 std::vector<llvm::Constant *> vals;
406 for (int i = 0; i < g->target->getVectorWidth(); ++i)
407 vals.push_back(LLVMUInt64(ivec[i]));
408 return llvm::ConstantVector::get(vals);
409 }
410
LLVMBoolVector(bool b)411 llvm::Constant *LLVMBoolVector(bool b) {
412 llvm::Constant *v;
413 if (LLVMTypes::BoolVectorType == LLVMTypes::Int64VectorType)
414 v = llvm::ConstantInt::get(LLVMTypes::Int64Type, b ? 0xffffffffffffffffull : 0, false /*unsigned*/);
415 else if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
416 v = llvm::ConstantInt::get(LLVMTypes::Int32Type, b ? 0xffffffff : 0, false /*unsigned*/);
417 else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
418 v = llvm::ConstantInt::get(LLVMTypes::Int16Type, b ? 0xffff : 0, false /*unsigned*/);
419 else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
420 v = llvm::ConstantInt::get(LLVMTypes::Int8Type, b ? 0xff : 0, false /*unsigned*/);
421 else {
422 Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
423 v = b ? LLVMTrue : LLVMFalse;
424 }
425
426 std::vector<llvm::Constant *> vals;
427 for (int i = 0; i < g->target->getVectorWidth(); ++i)
428 vals.push_back(v);
429 return llvm::ConstantVector::get(vals);
430 }
431
LLVMBoolVector(const bool * bvec)432 llvm::Constant *LLVMBoolVector(const bool *bvec) {
433 std::vector<llvm::Constant *> vals;
434 for (int i = 0; i < g->target->getVectorWidth(); ++i) {
435 llvm::Constant *v;
436 if (LLVMTypes::BoolVectorType == LLVMTypes::Int64VectorType)
437 v = llvm::ConstantInt::get(LLVMTypes::Int64Type, bvec[i] ? 0xffffffffffffffffull : 0, false /*unsigned*/);
438 else if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
439 v = llvm::ConstantInt::get(LLVMTypes::Int32Type, bvec[i] ? 0xffffffff : 0, false /*unsigned*/);
440 else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
441 v = llvm::ConstantInt::get(LLVMTypes::Int16Type, bvec[i] ? 0xffff : 0, false /*unsigned*/);
442 else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
443 v = llvm::ConstantInt::get(LLVMTypes::Int8Type, bvec[i] ? 0xff : 0, false /*unsigned*/);
444 else {
445 Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
446 v = bvec[i] ? LLVMTrue : LLVMFalse;
447 }
448
449 vals.push_back(v);
450 }
451 return llvm::ConstantVector::get(vals);
452 }
453
LLVMBoolVectorInStorage(bool b)454 llvm::Constant *LLVMBoolVectorInStorage(bool b) {
455 llvm::Constant *v = b ? LLVMTrueInStorage : LLVMFalseInStorage;
456 std::vector<llvm::Constant *> vals;
457 for (int i = 0; i < g->target->getVectorWidth(); ++i)
458 vals.push_back(v);
459 return llvm::ConstantVector::get(vals);
460 }
461
LLVMBoolVectorInStorage(const bool * bvec)462 llvm::Constant *LLVMBoolVectorInStorage(const bool *bvec) {
463 std::vector<llvm::Constant *> vals;
464 for (int i = 0; i < g->target->getVectorWidth(); ++i) {
465 llvm::Constant *v = llvm::ConstantInt::get(LLVMTypes::Int8Type, bvec[i] ? 0xff : 0, false /*unsigned*/);
466 vals.push_back(v);
467 }
468 return llvm::ConstantVector::get(vals);
469 }
470
LLVMIntAsType(int64_t val,llvm::Type * type)471 llvm::Constant *LLVMIntAsType(int64_t val, llvm::Type *type) {
472 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
473 llvm::FixedVectorType *vecType = llvm::dyn_cast<llvm::FixedVectorType>(type);
474 #else
475 llvm::VectorType *vecType = llvm::dyn_cast<llvm::VectorType>(type);
476 #endif
477
478 if (vecType != NULL) {
479 llvm::Constant *v = llvm::ConstantInt::get(vecType->getElementType(), val, true /* signed */);
480 std::vector<llvm::Constant *> vals;
481 for (int i = 0; i < (int)vecType->getNumElements(); ++i)
482 vals.push_back(v);
483 return llvm::ConstantVector::get(vals);
484 } else
485 return llvm::ConstantInt::get(type, val, true /* signed */);
486 }
487
LLVMUIntAsType(uint64_t val,llvm::Type * type)488 llvm::Constant *LLVMUIntAsType(uint64_t val, llvm::Type *type) {
489 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
490 llvm::FixedVectorType *vecType = llvm::dyn_cast<llvm::FixedVectorType>(type);
491 #else
492 llvm::VectorType *vecType = llvm::dyn_cast<llvm::VectorType>(type);
493 #endif
494
495 if (vecType != NULL) {
496 llvm::Constant *v = llvm::ConstantInt::get(vecType->getElementType(), val, false /* unsigned */);
497 std::vector<llvm::Constant *> vals;
498 for (int i = 0; i < (int)vecType->getNumElements(); ++i)
499 vals.push_back(v);
500 return llvm::ConstantVector::get(vals);
501 } else
502 return llvm::ConstantInt::get(type, val, false /* unsigned */);
503 }
504
505 /** Conservative test to see if two llvm::Values are equal. There are
506 (potentially many) cases where the two values actually are equal but
507 this will return false. However, if it does return true, the two
508 vectors definitely are equal.
509 */
static bool lValuesAreEqual(llvm::Value *v0, llvm::Value *v1, std::vector<llvm::PHINode *> &seenPhi0,
                            std::vector<llvm::PHINode *> &seenPhi1) {
    // Thanks to the fact that LLVM hashes and returns the same pointer for
    // constants (of all sorts, even constant expressions), this first test
    // actually catches a lot of cases.  LLVM's SSA form also helps a lot
    // with this..
    if (v0 == v1)
        return true;

    // seenPhi0/seenPhi1 hold pairs of phi nodes currently being compared
    // higher up the recursion (pair i is seenPhi0[i]/seenPhi1[i]).  If we
    // encounter the same pair again, treat them as equal -- this breaks the
    // otherwise-infinite recursion through loop-carried phi nodes.
    Assert(seenPhi0.size() == seenPhi1.size());
    for (unsigned int i = 0; i < seenPhi0.size(); ++i)
        if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
            return true;

    // Two binary operators are equal if they have the same opcode and both
    // operand pairs compare equal (recursively).
    llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
    llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
    if (bo0 != NULL && bo1 != NULL) {
        if (bo0->getOpcode() != bo1->getOpcode())
            return false;
        return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0), seenPhi0, seenPhi1) &&
                lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1), seenPhi0, seenPhi1));
    }

    // Two casts are equal if they have the same opcode and equal operands.
    llvm::CastInst *cast0 = llvm::dyn_cast<llvm::CastInst>(v0);
    llvm::CastInst *cast1 = llvm::dyn_cast<llvm::CastInst>(v1);
    if (cast0 != NULL && cast1 != NULL) {
        if (cast0->getOpcode() != cast1->getOpcode())
            return false;
        return lValuesAreEqual(cast0->getOperand(0), cast1->getOperand(0), seenPhi0, seenPhi1);
    }

    // For phi nodes, push the pair onto the "currently comparing" stacks and
    // require every pair of incoming values to compare equal.
    llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
    llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
    if (phi0 != NULL && phi1 != NULL) {
        if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
            return false;

        seenPhi0.push_back(phi0);
        seenPhi1.push_back(phi1);

        unsigned int numIncoming = phi0->getNumIncomingValues();
        // Check all of the incoming values: if all of them are all equal,
        // then we're good.
        bool anyFailure = false;
        for (unsigned int i = 0; i < numIncoming; ++i) {
            // FIXME: should it be ok if the incoming blocks are different,
            // where we just return faliure in this case?
            Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
            if (!lValuesAreEqual(phi0->getIncomingValue(i), phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
                anyFailure = true;
                break;
            }
        }

        // Pop the pair pushed above before returning, so the stacks only
        // reflect phis on the current recursion path.
        seenPhi0.pop_back();
        seenPhi1.pop_back();

        return !anyFailure;
    }

    // Any other value kinds: conservatively report "not known equal".
    return false;
}
572
573 /** Given an llvm::Value known to be an integer, return its value as
574 an int64_t.
575 */
lGetIntValue(llvm::Value * offset)576 static int64_t lGetIntValue(llvm::Value *offset) {
577 llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
578 Assert(intOffset && (intOffset->getBitWidth() == 32 || intOffset->getBitWidth() == 64));
579 return intOffset->getSExtValue();
580 }
581
582 /** Recognizes constant vector with undef operands except the first one:
583 * <i64 4, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef>
584 */
static bool lIsFirstElementConstVector(llvm::Value *v) {
    llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
    // FIXME: skipping instruction and using getOperand(1) without checking for instruction type is incorrect!
    // This yields failing tests/write-same-loc.ispc for 32 bit targets (x86).
    // Need to understand what initial intent was here (what instruction supposed to be handled).
    // TODO: after fixing FIXME above isGenXTarget() needs to be removed.
    if (g->target->isGenXTarget()) {
        // If v itself isn't a ConstantVector, look at the instruction's
        // second operand instead (see FIXME above -- this assumes the
        // instruction has at least two operands).
        if (cv == NULL && llvm::isa<llvm::Instruction>(v)) {
            cv = llvm::dyn_cast<llvm::ConstantVector>(llvm::dyn_cast<llvm::Instruction>(v)->getOperand(1));
        }
    }
    if (cv != NULL) {
        // The first element must be a constant...
        llvm::Constant *c = llvm::dyn_cast<llvm::Constant>(cv->getOperand(0));
        if (c == NULL) {
            return false;
        }

        // ...and every remaining element must be undef.
        for (int i = 1; i < (int)cv->getNumOperands(); ++i) {
            if (!llvm::isa<llvm::UndefValue>(cv->getOperand(i))) {
                return false;
            }
        }
        return true;
    }
    return false;
}
611
LLVMFlattenInsertChain(llvm::Value * inst,int vectorWidth,bool compare,bool undef,bool searchFirstUndef)612 llvm::Value *LLVMFlattenInsertChain(llvm::Value *inst, int vectorWidth, bool compare, bool undef,
613 bool searchFirstUndef) {
614 std::vector<llvm::Value *> elements(vectorWidth, nullptr);
615
616 // Catch a pattern of InsertElement chain.
617 if (llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(inst)) {
618 // Gather elements of vector
619 while (ie != NULL) {
620 int64_t iOffset = lGetIntValue(ie->getOperand(2));
621 Assert(iOffset >= 0 && iOffset < vectorWidth);
622
623 // Get the scalar value from this insert
624 if (elements[iOffset] == NULL) {
625 elements[iOffset] = ie->getOperand(1);
626 }
627
628 // Do we have another insert?
629 llvm::Value *insertBase = ie->getOperand(0);
630 ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
631 if (ie != NULL) {
632 continue;
633 }
634
635 if (llvm::isa<llvm::UndefValue>(insertBase)) {
636 break;
637 }
638
639 if (llvm::isa<llvm::ConstantVector>(insertBase) || llvm::isa<llvm::ConstantAggregateZero>(insertBase)) {
640 llvm::Constant *cv = llvm::dyn_cast<llvm::Constant>(insertBase);
641 Assert(vectorWidth == (int)(cv->getNumOperands()));
642 for (int i = 0; i < vectorWidth; i++) {
643 if (elements[i] == NULL) {
644 elements[i] = cv->getOperand(i);
645 }
646 }
647 break;
648 } else {
649 // Here chain ends in llvm::LoadInst or some other.
650 // They are not equal to each other so we should return NULL if compare
651 // and first element if we have it.
652 Assert(compare == true || elements[0] != NULL);
653 if (compare) {
654 return NULL;
655 } else {
656 return elements[0];
657 }
658 }
659 // TODO: Also, should we handle some other values like
660 // ConstantDataVectors.
661 }
662 if (compare == false) {
663 // We simply want first element
664 return elements[0];
665 }
666
667 int null_number = 0;
668 int NonNull = 0;
669 for (int i = 0; i < vectorWidth; i++) {
670 if (elements[i] == NULL) {
671 null_number++;
672 } else {
673 NonNull = i;
674 }
675 }
676 if (null_number == vectorWidth) {
677 // All of elements are NULLs
678 return NULL;
679 }
680 if ((undef == false) && (null_number != 0)) {
681 // We don't want NULLs in chain, but we have them
682 return NULL;
683 }
684
685 // Compare elements of vector
686 for (int i = 0; i < vectorWidth; i++) {
687 if (elements[i] == NULL) {
688 continue;
689 }
690
691 std::vector<llvm::PHINode *> seenPhi0;
692 std::vector<llvm::PHINode *> seenPhi1;
693 if (lValuesAreEqual(elements[NonNull], elements[i], seenPhi0, seenPhi1) == false) {
694 return NULL;
695 }
696 }
697 return elements[NonNull];
698 }
699
700 // Catch a pattern of broadcast implemented as InsertElement + Shuffle:
701 // %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0
702 // %broadcast.1 = shufflevector <4 x i32> %smear.0, <4 x i32> undef,
703 // <4 x i32> zeroinitializer
704 // Or:
705 // %gep_ptr2int_broadcast_init = insertelement <8 x i64> undef, i64 %gep_ptr2int, i32 0
706 // %0 = add <8 x i64> %gep_ptr2int_broadcast_init,
707 // <i64 4, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef>
708 // %gep_offset = shufflevector <8 x i64> %0, <8 x i64> undef, <8 x i32> zeroinitializer
709 else if (llvm::ShuffleVectorInst *shuf = llvm::dyn_cast<llvm::ShuffleVectorInst>(inst)) {
710 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
711 llvm::Value *indices = shuf->getShuffleMaskForBitcode();
712 #else
713 llvm::Value *indices = shuf->getOperand(2);
714 #endif
715
716 if (llvm::isa<llvm::ConstantAggregateZero>(indices)) {
717 llvm::Value *op = shuf->getOperand(0);
718 llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(op);
719 if (ie == NULL && searchFirstUndef) {
720 // Trying to recognize 2nd pattern
721 llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(op);
722 if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop))) {
723 if (lIsFirstElementConstVector(bop->getOperand(1))) {
724 ie = llvm::dyn_cast<llvm::InsertElementInst>(bop->getOperand(0));
725 } else if (llvm::isa<llvm::InsertElementInst>(bop->getOperand(1))) {
726 // Or shuffle vector can accept insertelement itself
727 ie = llvm::cast<llvm::InsertElementInst>(bop->getOperand(1));
728 }
729 }
730 }
731 if (ie != NULL && llvm::isa<llvm::UndefValue>(ie->getOperand(0))) {
732 llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(ie->getOperand(2));
733 if (ci->isZero()) {
734 return ie->getOperand(1);
735 }
736 }
737 }
738 }
739 return NULL;
740 }
741
LLVMExtractVectorInts(llvm::Value * v,int64_t ret[],int * nElts)742 bool LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) {
743 // Make sure we do in fact have a vector of integer values here
744 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
745 llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
746 #else
747 llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
748 #endif
749 Assert(vt != NULL);
750 Assert(llvm::isa<llvm::IntegerType>(vt->getElementType()));
751
752 *nElts = (int)vt->getNumElements();
753
754 if (llvm::isa<llvm::ConstantAggregateZero>(v)) {
755 for (int i = 0; i < (int)vt->getNumElements(); ++i)
756 ret[i] = 0;
757 return true;
758 }
759
760 llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
761 if (cv == NULL)
762 return false;
763
764 for (int i = 0; i < (int)cv->getNumElements(); ++i)
765 ret[i] = cv->getElementAsInteger(i);
766 return true;
767 }
768
769 static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector<llvm::PHINode *> &seenPhis,
770 llvm::Value **splatValue = NULL);
771
772 /** This function checks to see if the given (scalar or vector) value is an
773 exact multiple of baseValue. It returns true if so, and false if not
774 (or if it's not able to determine if it is). Any vector value passed
775 in is required to have the same value in all elements (so that we can
776 just check the first element to be a multiple of the given value.)
777 */
lIsExactMultiple(llvm::Value * val,int baseValue,int vectorLength,std::vector<llvm::PHINode * > & seenPhis)778 static bool lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength,
779 std::vector<llvm::PHINode *> &seenPhis) {
780 if (llvm::isa<llvm::VectorType>(val->getType()) == false) {
781 // If we've worked down to a constant int, then the moment of truth
782 // has arrived...
783 llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(val);
784 if (ci != NULL)
785 return (ci->getZExtValue() % baseValue) == 0;
786 } else
787 Assert(LLVMVectorValuesAllEqual(val));
788
789 if (llvm::isa<llvm::InsertElementInst>(val) || llvm::isa<llvm::ShuffleVectorInst>(val)) {
790 llvm::Value *element = LLVMFlattenInsertChain(val, g->target->getVectorWidth());
791 // We just need to check the scalar first value, since we know that
792 // all elements are equal
793 return element ? lIsExactMultiple(element, baseValue, vectorLength, seenPhis) : false;
794 }
795
796 llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
797 if (phi != NULL) {
798 for (unsigned int i = 0; i < seenPhis.size(); ++i)
799 if (phi == seenPhis[i])
800 return true;
801
802 seenPhis.push_back(phi);
803 unsigned int numIncoming = phi->getNumIncomingValues();
804
805 // Check all of the incoming values: if all of them pass, then
806 // we're good.
807 for (unsigned int i = 0; i < numIncoming; ++i) {
808 llvm::Value *incoming = phi->getIncomingValue(i);
809 bool mult = lIsExactMultiple(incoming, baseValue, vectorLength, seenPhis);
810 if (mult == false) {
811 seenPhis.pop_back();
812 return false;
813 }
814 }
815 seenPhis.pop_back();
816 return true;
817 }
818
819 llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
820 if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop))) {
821 llvm::Value *op0 = bop->getOperand(0);
822 llvm::Value *op1 = bop->getOperand(1);
823
824 bool be0 = lIsExactMultiple(op0, baseValue, vectorLength, seenPhis);
825 bool be1 = lIsExactMultiple(op1, baseValue, vectorLength, seenPhis);
826 return (be0 && be1);
827 }
828 // FIXME: mul? casts? ... ?
829
830 return false;
831 }
832
/** Returns the smallest power of two that is greater than or equal to
    the given value. */
static int lRoundUpPow2(int v) {
    // Smear the highest set bit of (v - 1) into every lower bit position,
    // then add one to land on the next power of two.  Starting from v - 1
    // makes exact powers of two map to themselves.
    v -= 1;
    for (int shift = 1; shift < 32; shift *= 2)
        v |= v >> shift;
    return v + 1;
}
844
/** Try to determine if all of the elements of the given vector value have
    the same value when divided by the given baseValue.  The function
    returns true if this can be determined to be the case, and false
    otherwise.  (This function may fail to identify some cases where it
    does in fact have this property, but should never report a given value
    as being a multiple if it isn't!)

    The 'canAdd' in/out flag limits the analysis to at most one add along
    a chain of dependent values; see the comment where it is cleared below.
 */
static bool lAllDivBaseEqual(llvm::Value *val, int64_t baseValue, int vectorLength,
                             std::vector<llvm::PHINode *> &seenPhis, bool &canAdd) {
    Assert(llvm::isa<llvm::VectorType>(val->getType()));
    // Make sure the base value is a positive power of 2
    Assert(baseValue > 0 && (baseValue & (baseValue - 1)) == 0);

    // The easy case
    if (lVectorValuesAllEqual(val, vectorLength, seenPhis))
        return true;

    int64_t vecVals[ISPC_MAX_NVEC];
    int nElts;
    if (llvm::isa<llvm::VectorType>(val->getType()) && LLVMExtractVectorInts(val, vecVals, &nElts)) {
        // If we have a vector of compile-time constant integer values,
        // then go ahead and check them directly..
        int64_t firstDiv = vecVals[0] / baseValue;
        for (int i = 1; i < nElts; ++i)
            if ((vecVals[i] / baseValue) != firstDiv)
                return false;

        return true;
    }

    llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
    if (phi != NULL) {
        // If we revisit a phi that's already on the analysis stack,
        // optimistically treat it as passing; its other incoming values
        // are still being verified by the callers above us.
        for (unsigned int i = 0; i < seenPhis.size(); ++i)
            if (phi == seenPhis[i])
                return true;

        seenPhis.push_back(phi);
        unsigned int numIncoming = phi->getNumIncomingValues();

        // Check all of the incoming values: if all of them pass, then
        // we're good.
        for (unsigned int i = 0; i < numIncoming; ++i) {
            llvm::Value *incoming = phi->getIncomingValue(i);
            // Each incoming value gets its own copy of canAdd so that
            // consuming the single allowed add along one incoming path
            // doesn't forbid an add along a different path.
            bool ca = canAdd;
            bool mult = lAllDivBaseEqual(incoming, baseValue, vectorLength, seenPhis, ca);
            if (mult == false) {
                seenPhis.pop_back();
                return false;
            }
        }
        seenPhis.pop_back();
        return true;
    }

    llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
    if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) && canAdd == true) {
        llvm::Value *op0 = bop->getOperand(0);
        llvm::Value *op1 = bop->getOperand(1);

        // Otherwise we're only going to worry about the following case,
        // which comes up often when looping over SOA data:
        // ashr %val, <constant shift>
        // where %val = add %smear, <0,1,2,3...>
        // and where the maximum of the <0,...> vector in the add is less than
        //  1<<(constant shift),
        // and where %smear is a smear of a value that is a multiple of
        // baseValue.

        int64_t addConstants[ISPC_MAX_NVEC];
        if (LLVMExtractVectorInts(op1, addConstants, &nElts) == false)
            return false;
        Assert(nElts == vectorLength);

        // Do all of them give the same value when divided by baseValue?
        int64_t firstConstDiv = addConstants[0] / baseValue;
        for (int i = 1; i < vectorLength; ++i)
            if ((addConstants[i] / baseValue) != firstConstDiv)
                return false;

        if (lVectorValuesAllEqual(op0, vectorLength, seenPhis) == false)
            return false;

        // Note that canAdd is a reference parameter; setting this ensures
        // that we don't allow multiple adds in other parts of the chain of
        // dependent values from here.
        canAdd = false;

        // Now we need to figure out the required alignment (in numbers of
        // elements of the underlying type being indexed) of the value to
        // which these integer addConstant[] values are being added to.  We
        // know that we have addConstant[] values that all give the same
        // value when divided by baseValue, but we may need a less-strict
        // alignment than baseValue depending on the actual values.
        //
        // As an example, consider a case where the baseValue alignment is
        // 16, but the addConstants here are <0,1,2,3>.  In that case, the
        // value to which addConstants is added to only needs to be a
        // multiple of 4.  Conversely, if addConstants are <4,5,6,7>, then
        // we need a multiple of 8 to ensure that the final added result
        // will still have the same value for all vector elements when
        // divided by baseValue.
        //
        // All that said, here we need to find the maximum value of any of
        // the addConstants[], mod baseValue.  If we round that up to the
        // next power of 2, we'll have a value that will be no greater than
        // baseValue and sometimes less.
        int maxMod = int(addConstants[0] % baseValue);
        for (int i = 1; i < vectorLength; ++i)
            maxMod = std::max(maxMod, int(addConstants[i] % baseValue));
        int requiredAlignment = lRoundUpPow2(maxMod);

        // Use a fresh phi stack for the exact-multiple check; it is an
        // independent traversal of the use-def graph.
        std::vector<llvm::PHINode *> seenPhisEEM;
        return lIsExactMultiple(op0, requiredAlignment, vectorLength, seenPhisEEM);
    }
    // TODO: could handle mul by a vector of equal constant integer values
    // and the like here and adjust the 'baseValue' value when it evenly
    // divides, but unclear if it's worthwhile...

    return false;
}
965
966 /** Given a vector shift right of some value by some amount, try to
967 determine if all of the elements of the final result have the same
968 value (i.e. whether the high bits are all equal, disregarding the low
969 bits that are shifted out.) Returns true if so, and false otherwise.
970 */
lVectorShiftRightAllEqual(llvm::Value * val,llvm::Value * shift,int vectorLength)971 static bool lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift, int vectorLength) {
972 // Are we shifting all elements by a compile-time constant amount? If
973 // not, give up.
974 int64_t shiftAmount[ISPC_MAX_NVEC];
975 int nElts;
976 if (LLVMExtractVectorInts(shift, shiftAmount, &nElts) == false)
977 return false;
978 Assert(nElts == vectorLength);
979
980 // Is it the same amount for all elements?
981 for (int i = 0; i < vectorLength; ++i)
982 if (shiftAmount[i] != shiftAmount[0])
983 return false;
984
985 // Now see if the value divided by (1 << shift) can be determined to
986 // have the same value for all vector elements.
987 int pow2 = 1 << shiftAmount[0];
988 bool canAdd = true;
989 std::vector<llvm::PHINode *> seenPhis;
990 bool eq = lAllDivBaseEqual(val, pow2, vectorLength, seenPhis, canAdd);
991 #if 0
992 fprintf(stderr, "check all div base equal:\n");
993 LLVMDumpValue(shift);
994 LLVMDumpValue(val);
995 fprintf(stderr, "----> %s\n\n", eq ? "true" : "false");
996 #endif
997 return eq;
998 }
999
/** Conservatively determines whether all 'vectorLength' lanes of the
    vector value 'v' hold the same value.  'seenPhis' tracks the PHI nodes
    currently being examined so that cycles in the use-def graph
    terminate.  If 'splatValue' is non-NULL and the common scalar value is
    directly available (constant cases only), it is stored there. */
static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector<llvm::PHINode *> &seenPhis,
                                  llvm::Value **splatValue) {
    // A one-wide vector trivially has all-equal elements.
    if (vectorLength == 1)
        return true;

    if (llvm::isa<llvm::ConstantAggregateZero>(v)) {
        if (splatValue) {
            llvm::ConstantAggregateZero *caz = llvm::dyn_cast<llvm::ConstantAggregateZero>(v);
            *splatValue = caz->getSequentialElement();
        }
        return true;
    }

    llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
    if (cv != NULL) {
        llvm::Value *splat = cv->getSplatValue();
        if (splat != NULL && splatValue) {
            *splatValue = splat;
        }
        return (splat != NULL);
    }

    llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
    if (cdv != NULL) {
        llvm::Value *splat = cdv->getSplatValue();
        if (splat != NULL && splatValue) {
            *splatValue = splat;
        }
        return (splat != NULL);
    }

    llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
    if (bop != NULL) {
        // Easy case: both operands are all equal -> return true
        if (lVectorValuesAllEqual(bop->getOperand(0), vectorLength, seenPhis) &&
            lVectorValuesAllEqual(bop->getOperand(1), vectorLength, seenPhis))
            return true;

        // If it's a shift, take a special path that tries to check if the
        // high (surviving) bits of the values are equal.
        if (bop->getOpcode() == llvm::Instruction::AShr || bop->getOpcode() == llvm::Instruction::LShr)
            return lVectorShiftRightAllEqual(bop->getOperand(0), bop->getOperand(1), vectorLength);

        return false;
    }

    // Casts are lane-wise, so they preserve all-equality; look through them.
    llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
    if (cast != NULL)
        return lVectorValuesAllEqual(cast->getOperand(0), vectorLength, seenPhis);

    // An insertelement chain is all-equal iff it flattens to a single
    // broadcast scalar.
    llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
    if (ie != NULL) {
        return (LLVMFlattenInsertChain(ie, vectorLength) != NULL);
    }

    llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
    if (phi) {
        // If we revisit a phi already on the analysis path, optimistically
        // report success; its other incoming values are still being
        // checked by the callers above us.
        for (unsigned int i = 0; i < seenPhis.size(); ++i)
            if (seenPhis[i] == phi)
                return true;

        seenPhis.push_back(phi);

        unsigned int numIncoming = phi->getNumIncomingValues();
        // Check all of the incoming values: if all of them are all equal,
        // then we're good.
        for (unsigned int i = 0; i < numIncoming; ++i) {
            if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength, seenPhis)) {
                seenPhis.pop_back();
                return false;
            }
        }

        seenPhis.pop_back();
        return true;
    }

    if (llvm::isa<llvm::UndefValue>(v))
        // ?
        return false;

    // All constant cases should have been handled above.
    Assert(!llvm::isa<llvm::Constant>(v));

    // Calls, loads, and non-instruction values are opaque to this
    // analysis; give up.
    if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) || !llvm::isa<llvm::Instruction>(v))
        return false;

    llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
    if (shuffle != NULL) {
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
        llvm::Value *indices = shuffle->getShuffleMaskForBitcode();
#else
        llvm::Value *indices = shuffle->getOperand(2);
#endif

        if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
            // The easy case--just a smear of the same element across the
            // whole vector.
            return true;

        // TODO: handle more general cases?
        return false;
    }

#if 0
    fprintf(stderr, "all equal: ");
    v->dump();
    fprintf(stderr, "\n");
    llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
    if (inst) {
        inst->getParent()->dump();
        fprintf(stderr, "\n");
        fprintf(stderr, "\n");
    }
#endif

    return false;
}
1117
1118 /** Tests to see if all of the elements of the vector in the 'v' parameter
1119 are equal. This is a conservative test and may return false for arrays
1120 where the values are actually all equal.
1121 */
LLVMVectorValuesAllEqual(llvm::Value * v,llvm::Value ** splat)1122 bool LLVMVectorValuesAllEqual(llvm::Value *v, llvm::Value **splat) {
1123 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1124 llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
1125 #else
1126 llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
1127 #endif
1128 Assert(vt != NULL);
1129 int vectorLength = vt->getNumElements();
1130
1131 std::vector<llvm::PHINode *> seenPhis;
1132 bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis, splat);
1133
1134 Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.", v->getName().str().c_str(), equal ? "true" : "false");
1135 #ifndef ISPC_NO_DUMPS
1136 if (g->debugPrint)
1137 LLVMDumpValue(v);
1138 #endif
1139
1140 return equal;
1141 }
1142
1143 /** Tests to see if a binary operator has an OR which is equivalent to an ADD.*/
IsOrEquivalentToAdd(llvm::Value * op)1144 bool IsOrEquivalentToAdd(llvm::Value *op) {
1145 bool isEq = false;
1146 llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(op);
1147 if (bop != NULL && bop->getOpcode() == llvm::Instruction::Or) {
1148 // Special case when A+B --> A|B transformation is triggered
1149 // We need to prove that A|B == A+B
1150 llvm::Module *module = bop->getParent()->getParent()->getParent();
1151 llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
1152 if (!haveNoCommonBitsSet(op0, op1, module->getDataLayout()) == false) {
1153 // Fallback to A+B case
1154 isEq = true;
1155 }
1156 }
1157 return isEq;
1158 }
1159
1160 static bool lVectorIsLinear(llvm::Value *v, int vectorLength, int stride, std::vector<llvm::PHINode *> &seenPhis);
1161
1162 /** Given a vector of compile-time constant integer values, test to see if
1163 they are a linear sequence of constant integers starting from an
1164 arbirary value but then having a step of value "stride" between
1165 elements.
1166 */
lVectorIsLinearConstantInts(llvm::ConstantDataVector * cv,int vectorLength,int stride)1167 static bool lVectorIsLinearConstantInts(llvm::ConstantDataVector *cv, int vectorLength, int stride) {
1168 // Flatten the vector out into the elements array
1169 llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
1170 for (int i = 0; i < (int)cv->getNumElements(); ++i)
1171 elements.push_back(cv->getElementAsConstant(i));
1172 Assert((int)elements.size() == vectorLength);
1173
1174 llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(elements[0]);
1175 if (ci == NULL)
1176 // Not a vector of integers
1177 return false;
1178
1179 int64_t prevVal = ci->getSExtValue();
1180
1181 // For each element in the array, see if it is both a ConstantInt and
1182 // if the difference between it and the value of the previous element
1183 // is stride. If not, fail.
1184 for (int i = 1; i < vectorLength; ++i) {
1185 ci = llvm::dyn_cast<llvm::ConstantInt>(elements[i]);
1186 if (ci == NULL)
1187 return false;
1188
1189 int64_t nextVal = ci->getSExtValue();
1190 if (prevVal + stride != nextVal)
1191 return false;
1192
1193 prevVal = nextVal;
1194 }
1195 return true;
1196 }
1197
1198 /** Checks to see if (op0 * op1) is a linear vector where the result is a
1199 vector with values that increase by stride.
1200 */
lCheckMulForLinear(llvm::Value * op0,llvm::Value * op1,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1201 static bool lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, int stride,
1202 std::vector<llvm::PHINode *> &seenPhis) {
1203 // Is the first operand a constant integer value splatted across all of
1204 // the lanes?
1205 llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(op0);
1206 if (cv == NULL)
1207 return false;
1208
1209 llvm::Constant *csplat = cv->getSplatValue();
1210 if (csplat == NULL)
1211 return false;
1212
1213 llvm::ConstantInt *splat = llvm::dyn_cast<llvm::ConstantInt>(csplat);
1214 if (splat == NULL)
1215 return false;
1216
1217 // If the splat value doesn't evenly divide the stride we're looking
1218 // for, there's no way that we can get the linear sequence we're
1219 // looking or.
1220 int64_t splatVal = splat->getSExtValue();
1221 if (splatVal == 0 || splatVal > stride || (stride % splatVal) != 0)
1222 return false;
1223
1224 // Check to see if the other operand is a linear vector with stride
1225 // given by stride/splatVal.
1226 return lVectorIsLinear(op1, vectorLength, (int)(stride / splatVal), seenPhis);
1227 }
1228
1229 /** Checks to see if (op0 << op1) is a linear vector where the result is a
1230 vector with values that increase by stride.
1231 */
lCheckShlForLinear(llvm::Value * op0,llvm::Value * op1,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1232 static bool lCheckShlForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, int stride,
1233 std::vector<llvm::PHINode *> &seenPhis) {
1234 // Is the second operand a constant integer value splatted across all of
1235 // the lanes?
1236 llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(op1);
1237 if (cv == NULL)
1238 return false;
1239
1240 llvm::Constant *csplat = cv->getSplatValue();
1241 if (csplat == NULL)
1242 return false;
1243
1244 llvm::ConstantInt *splat = llvm::dyn_cast<llvm::ConstantInt>(csplat);
1245 if (splat == NULL)
1246 return false;
1247
1248 // If (1 << the splat value) doesn't evenly divide the stride we're
1249 // looking for, there's no way that we can get the linear sequence
1250 // we're looking or.
1251 int64_t equivalentMul = (1LL << splat->getSExtValue());
1252 if (equivalentMul > stride || (stride % equivalentMul) != 0)
1253 return false;
1254
1255 // Check to see if the first operand is a linear vector with stride
1256 // given by stride/splatVal.
1257 return lVectorIsLinear(op0, vectorLength, (int)(stride / equivalentMul), seenPhis);
1258 }
1259
1260 /** Given (op0 AND op1), try and see if we can determine if the result is a
1261 linear sequence with a step of "stride" between values. Returns true
1262 if so and false otherwise. This pattern comes up when accessing SOA
1263 data.
1264 */
lCheckAndForLinear(llvm::Value * op0,llvm::Value * op1,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1265 static bool lCheckAndForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, int stride,
1266 std::vector<llvm::PHINode *> &seenPhis) {
1267 // Require op1 to be a compile-time constant
1268 int64_t maskValue[ISPC_MAX_NVEC];
1269 int nElts;
1270 if (LLVMExtractVectorInts(op1, maskValue, &nElts) == false)
1271 return false;
1272 Assert(nElts == vectorLength);
1273
1274 // Is op1 a smear of the same value across all lanes? Give up if not.
1275 for (int i = 1; i < vectorLength; ++i)
1276 if (maskValue[i] != maskValue[0])
1277 return false;
1278
1279 // If the op1 value isn't power of 2 minus one, then also give up.
1280 int64_t maskPlusOne = maskValue[0] + 1;
1281 bool isPowTwo = (maskPlusOne & (maskPlusOne - 1)) == 0;
1282 if (isPowTwo == false)
1283 return false;
1284
1285 // The case we'll covert here is op0 being a linear vector with desired
1286 // stride, and where all of the values of op0, when divided by
1287 // maskPlusOne, have the same value.
1288 if (lVectorIsLinear(op0, vectorLength, stride, seenPhis) == false)
1289 return false;
1290
1291 bool canAdd = true;
1292 bool isMult = lAllDivBaseEqual(op0, maskPlusOne, vectorLength, seenPhis, canAdd);
1293 return isMult;
1294 }
1295
/** Conservatively checks whether the vector value 'v' is a linear
    sequence whose consecutive lanes differ by 'stride'.  Returns true
    only when this can be proven from the instruction chain; 'seenPhis'
    holds the PHI nodes on the current analysis path so that cyclic
    use-def chains terminate. */
static bool lVectorIsLinear(llvm::Value *v, int vectorLength, int stride, std::vector<llvm::PHINode *> &seenPhis) {
    // First try the easy case: if the values are all just constant
    // integers and have the expected stride between them, then we're done.
    llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
    if (cv != NULL)
        return lVectorIsLinearConstantInts(cv, vectorLength, stride);

    llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
    if (bop != NULL) {
        // FIXME: is it right to pass the seenPhis to the all equal check as well??
        llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
        if ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) {
            // There are two cases to check if we have an add:
            //
            // programIndex + unif -> ascending linear seqeuence
            // unif + programIndex -> ascending linear sequence
            bool l0 = lVectorIsLinear(op0, vectorLength, stride, seenPhis);
            bool e1 = lVectorValuesAllEqual(op1, vectorLength, seenPhis);
            if (l0 && e1)
                return true;

            bool e0 = lVectorValuesAllEqual(op0, vectorLength, seenPhis);
            bool l1 = lVectorIsLinear(op1, vectorLength, stride, seenPhis);
            return (e0 && l1);
        } else if (bop->getOpcode() == llvm::Instruction::Sub)
            // For subtraction, we only match:
            // programIndex - unif -> ascending linear seqeuence
            return (lVectorIsLinear(bop->getOperand(0), vectorLength, stride, seenPhis) &&
                    lVectorValuesAllEqual(bop->getOperand(1), vectorLength, seenPhis));
        else if (bop->getOpcode() == llvm::Instruction::Mul) {
            // Multiplies are a bit trickier, so are handled in a separate
            // function.
            bool m0 = lCheckMulForLinear(op0, op1, vectorLength, stride, seenPhis);
            if (m0)
                return true;
            bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
            return m1;
        } else if (bop->getOpcode() == llvm::Instruction::Shl) {
            // Sometimes multiplies come in as shift lefts (especially in
            // LLVM 3.4+).
            bool linear = lCheckShlForLinear(op0, op1, vectorLength, stride, seenPhis);
            return linear;
        } else if (bop->getOpcode() == llvm::Instruction::And) {
            // Special case for some AND-related patterns that come up when
            // looping over SOA data
            bool linear = lCheckAndForLinear(op0, op1, vectorLength, stride, seenPhis);
            return linear;
        } else
            return false;
    }

    // Casts are lane-wise, so look through them.
    llvm::CastInst *ci = llvm::dyn_cast<llvm::CastInst>(v);
    if (ci != NULL)
        return lVectorIsLinear(ci->getOperand(0), vectorLength, stride, seenPhis);

    // Calls and loads are opaque to this analysis; give up.
    if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v))
        return false;

    llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
    if (phi != NULL) {
        // If we revisit a phi already on the analysis path, optimistically
        // report success; its other incoming values are still being
        // checked by the callers above us.
        for (unsigned int i = 0; i < seenPhis.size(); ++i)
            if (seenPhis[i] == phi)
                return true;

        seenPhis.push_back(phi);

        unsigned int numIncoming = phi->getNumIncomingValues();
        // Check all of the incoming values: if all of them are all equal,
        // then we're good.
        for (unsigned int i = 0; i < numIncoming; ++i) {
            if (!lVectorIsLinear(phi->getIncomingValue(i), vectorLength, stride, seenPhis)) {
                seenPhis.pop_back();
                return false;
            }
        }

        seenPhis.pop_back();
        return true;
    }

    // TODO: is any reason to worry about these?
    if (llvm::isa<llvm::InsertElementInst>(v))
        return false;

    // TODO: we could also handle shuffles, but we haven't yet seen any
    // cases where doing so would detect cases where actually have a linear
    // vector.
    llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
    if (shuffle != NULL)
        return false;

#if 0
    fprintf(stderr, "linear check: ");
    v->dump();
    fprintf(stderr, "\n");
    llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
    if (inst) {
        inst->getParent()->dump();
        fprintf(stderr, "\n");
        fprintf(stderr, "\n");
    }
#endif

    return false;
}
1401
1402 /** Given vector of integer-typed values, see if the elements of the array
1403 have a step of 'stride' between their values. This function tries to
1404 handle as many possibilities as possible, including things like all
1405 elements equal to some non-constant value plus an integer offset, etc.
1406 */
LLVMVectorIsLinear(llvm::Value * v,int stride)1407 bool LLVMVectorIsLinear(llvm::Value *v, int stride) {
1408 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1409 llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
1410 #else
1411 llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
1412 #endif
1413 Assert(vt != NULL);
1414 int vectorLength = vt->getNumElements();
1415
1416 std::vector<llvm::PHINode *> seenPhis;
1417 bool linear = lVectorIsLinear(v, vectorLength, stride, seenPhis);
1418 Debug(SourcePos(), "LLVMVectorIsLinear(%s) -> %s.", v->getName().str().c_str(), linear ? "true" : "false");
1419 #ifndef ISPC_NO_DUMPS
1420 if (g->debugPrint)
1421 LLVMDumpValue(v);
1422 #endif
1423
1424 return linear;
1425 }
1426
1427 #ifndef ISPC_NO_DUMPS
lDumpValue(llvm::Value * v,std::set<llvm::Value * > & done)1428 static void lDumpValue(llvm::Value *v, std::set<llvm::Value *> &done) {
1429 if (done.find(v) != done.end())
1430 return;
1431
1432 llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
1433 if (done.size() > 0 && inst == NULL)
1434 return;
1435
1436 fprintf(stderr, " ");
1437 //v->dump();
1438 done.insert(v);
1439
1440 if (inst == NULL)
1441 return;
1442
1443 for (unsigned i = 0; i < inst->getNumOperands(); ++i)
1444 lDumpValue(inst->getOperand(i), done);
1445 }
1446
LLVMDumpValue(llvm::Value * v)1447 void LLVMDumpValue(llvm::Value *v) {
1448 std::set<llvm::Value *> done;
1449 lDumpValue(v, done);
1450 fprintf(stderr, "----\n");
1451 }
1452 #endif
1453
/** Returns a scalar llvm::Value holding element 0 of the vector value
    'v'.  Where possible this scalarizes the computation itself (binary
    ops, casts, phis) by creating new scalar instructions next to the
    vector originals, rather than emitting an extractelement; 'phiMap'
    maps already-scalarized vector phis to their scalar counterparts so
    that phi cycles terminate. */
static llvm::Value *lExtractFirstVectorElement(llvm::Value *v, std::map<llvm::PHINode *, llvm::PHINode *> &phiMap) {
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
    llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
#else
    llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
#endif
    Assert(vt != NULL);

    // First, handle various constant types; do the extraction manually, as
    // appropriate.
    if (llvm::isa<llvm::ConstantAggregateZero>(v) == true) {
        return llvm::Constant::getNullValue(vt->getElementType());
    }
    if (llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v)) {
        return cv->getOperand(0);
    }
    if (llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v))
        return cdv->getElementAsConstant(0);

    // Otherwise, all that we should have at this point is an instruction
    // of some sort
    Assert(llvm::isa<llvm::Constant>(v) == false);
    Assert(llvm::isa<llvm::Instruction>(v) == true);

    std::string newName = v->getName().str() + std::string(".elt0");

    // Rewrite regular binary operators and casts to the scalarized
    // equivalent.
    llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
    if (bop != NULL) {
        llvm::Value *v0 = lExtractFirstVectorElement(bop->getOperand(0), phiMap);
        llvm::Value *v1 = lExtractFirstVectorElement(bop->getOperand(1), phiMap);
        Assert(v0 != NULL);
        Assert(v1 != NULL);
        // Note that the new binary operator is inserted immediately before
        // the previous vector one
        return llvm::BinaryOperator::Create(bop->getOpcode(), v0, v1, newName, bop);
    }

    llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
    if (cast != NULL) {
        llvm::Value *v = lExtractFirstVectorElement(cast->getOperand(0), phiMap);
        // Similarly, the equivalent scalar cast instruction goes right
        // before the vector cast
        return llvm::CastInst::Create(cast->getOpcode(), v, vt->getElementType(), newName, cast);
    }

    llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
    if (phi != NULL) {
        // For PHI notes, recursively scalarize them.
        if (phiMap.find(phi) != phiMap.end())
            return phiMap[phi];

        // We need to create the new scalar PHI node immediately, though,
        // and put it in the map<>, so that if we come back to this node
        // via a recursive lExtractFirstVectorElement() call, then we can
        // return the pointer and not get stuck in an infinite loop.
        //
        // The insertion point for the new phi node also has to be the
        // start of the bblock of the original phi node.

        llvm::Instruction *phiInsertPos = &*(phi->getParent()->begin());
        llvm::PHINode *scalarPhi =
            llvm::PHINode::Create(vt->getElementType(), phi->getNumIncomingValues(), newName, phiInsertPos);
        phiMap[phi] = scalarPhi;

        for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) {
            llvm::Value *v = lExtractFirstVectorElement(phi->getIncomingValue(i), phiMap);
            scalarPhi->addIncoming(v, phi->getIncomingBlock(i));
        }

        return scalarPhi;
    }

    // We should consider "shuffle" case and "insertElement" case separately.
    // For example we can have shuffle(mul, undef, zero) but function
    // "LLVMFlattenInsertChain" can handle only case shuffle(insertElement, undef, zero).
    // Also if we have insertElement under shuffle we will handle it the next call of
    // "lExtractFirstVectorElement" function.
    if (llvm::isa<llvm::ShuffleVectorInst>(v)) {
        llvm::ShuffleVectorInst *shuf = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
        llvm::Value *indices = shuf->getShuffleMaskForBitcode();
#else
        llvm::Value *indices = shuf->getOperand(2);
#endif
        // An all-zero mask broadcasts lane 0 of the first operand, so the
        // first element of the shuffle is the first element of operand 0.
        if (llvm::isa<llvm::ConstantAggregateZero>(indices)) {
            return lExtractFirstVectorElement(shuf->getOperand(0), phiMap);
        }
    }

    // If we have a chain of insertelement instructions, then we can just
    // flatten them out and grab the value for the first one.
    if (llvm::isa<llvm::InsertElementInst>(v)) {
        return LLVMFlattenInsertChain(v, vt->getNumElements(), false);
    }

    // Worst case, for everything else, just do a regular extract element
    // instruction, which we insert immediately after the instruction we
    // have here.
    llvm::Instruction *insertAfter = llvm::dyn_cast<llvm::Instruction>(v);
    Assert(insertAfter != NULL);
    llvm::Instruction *ee = llvm::ExtractElementInst::Create(v, LLVMInt32(0), "first_elt", (llvm::Instruction *)NULL);
    ee->insertAfter(insertAfter);
    return ee;
}
1560
LLVMExtractFirstVectorElement(llvm::Value * v)1561 llvm::Value *LLVMExtractFirstVectorElement(llvm::Value *v) {
1562 std::map<llvm::PHINode *, llvm::PHINode *> phiMap;
1563 llvm::Value *ret = lExtractFirstVectorElement(v, phiMap);
1564 return ret;
1565 }
1566
1567 /** Given two vectors of the same type, concatenate them into a vector that
1568 has twice as many elements, where the first half has the elements from
1569 the first vector and the second half has the elements from the second
1570 vector.
1571 */
LLVMConcatVectors(llvm::Value * v1,llvm::Value * v2,llvm::Instruction * insertBefore)1572 llvm::Value *LLVMConcatVectors(llvm::Value *v1, llvm::Value *v2, llvm::Instruction *insertBefore) {
1573 Assert(v1->getType() == v2->getType());
1574 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1575 llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v1->getType());
1576 #else
1577 llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v1->getType());
1578 #endif
1579 Assert(vt != NULL);
1580
1581 int32_t identity[ISPC_MAX_NVEC];
1582 int resultSize = 2 * vt->getNumElements();
1583 Assert(resultSize <= ISPC_MAX_NVEC);
1584 for (int i = 0; i < resultSize; ++i)
1585 identity[i] = i;
1586
1587 return LLVMShuffleVectors(v1, v2, identity, resultSize, insertBefore);
1588 }
1589
1590 /** Shuffle two vectors together with a ShuffleVectorInst, returning a
1591 vector with shufSize elements, where the shuf[] array offsets are used
1592 to determine which element from the two given vectors is used for each
1593 result element. */
LLVMShuffleVectors(llvm::Value * v1,llvm::Value * v2,int32_t shuf[],int shufSize,llvm::Instruction * insertBefore)1594 llvm::Value *LLVMShuffleVectors(llvm::Value *v1, llvm::Value *v2, int32_t shuf[], int shufSize,
1595 llvm::Instruction *insertBefore) {
1596 std::vector<llvm::Constant *> shufVec;
1597 for (int i = 0; i < shufSize; ++i) {
1598 if (shuf[i] == -1)
1599 shufVec.push_back(llvm::UndefValue::get(LLVMTypes::Int32Type));
1600 else
1601 shufVec.push_back(LLVMInt32(shuf[i]));
1602 }
1603
1604 llvm::ArrayRef<llvm::Constant *> aref(&shufVec[0], &shufVec[shufSize]);
1605 llvm::Value *vec = llvm::ConstantVector::get(aref);
1606
1607 return new llvm::ShuffleVectorInst(v1, v2, vec, "shuffle", insertBefore);
1608 }
1609
1610 #ifdef ISPC_GENX_ENABLED
lIsSVMLoad(llvm::Instruction * inst)1611 bool lIsSVMLoad(llvm::Instruction *inst) {
1612 Assert(inst);
1613
1614 switch (llvm::GenXIntrinsic::getGenXIntrinsicID(inst)) {
1615 case llvm::GenXIntrinsic::genx_svm_block_ld:
1616 case llvm::GenXIntrinsic::genx_svm_block_ld_unaligned:
1617 case llvm::GenXIntrinsic::genx_svm_gather:
1618 return true;
1619 default:
1620 return false;
1621 }
1622 }
1623
lGetAddressSpace(llvm::Value * v,std::set<llvm::Value * > & done,std::set<AddressSpace> & addrSpaceVec)1624 void lGetAddressSpace(llvm::Value *v, std::set<llvm::Value *> &done, std::set<AddressSpace> &addrSpaceVec) {
1625 if (done.find(v) != done.end()) {
1626 if (llvm::isa<llvm::PointerType>(v->getType()))
1627 addrSpaceVec.insert(AddressSpace::External);
1628 return;
1629 }
1630 // Found global value
1631 if (llvm::isa<llvm::GlobalValue>(v)) {
1632 addrSpaceVec.insert(AddressSpace::Global);
1633 return;
1634 }
1635
1636 llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
1637 bool isConstExpr = false;
1638 if (inst == NULL) {
1639 // Case when GEP is constant expression
1640 if (llvm::ConstantExpr *constExpr = llvm::dyn_cast<llvm::ConstantExpr>(v)) {
1641 // This instruction isn't inserted anywhere, so delete it when done
1642 inst = constExpr->getAsInstruction();
1643 isConstExpr = true;
1644 }
1645 }
1646
1647 if (done.size() > 0 && inst == NULL) {
1648 // Found external pointer like "float* %aFOO"
1649 if (llvm::isa<llvm::PointerType>(v->getType()))
1650 addrSpaceVec.insert(AddressSpace::External);
1651 return;
1652 }
1653
1654 done.insert(v);
1655
1656 // Found value allocated on stack like "%val = alloca [16 x float]"
1657 if (llvm::isa<llvm::AllocaInst>(v)) {
1658 addrSpaceVec.insert(AddressSpace::Local);
1659 return;
1660 }
1661
1662 if (inst == NULL || llvm::isa<llvm::CallInst>(v)) {
1663 if (llvm::isa<llvm::PointerType>(v->getType()) || (inst && lIsSVMLoad(inst)))
1664 addrSpaceVec.insert(AddressSpace::External);
1665 return;
1666 }
1667 // For GEP we check only pointer operand, for the rest check all
1668 if (llvm::isa<llvm::GetElementPtrInst>(inst)) {
1669 lGetAddressSpace(inst->getOperand(0), done, addrSpaceVec);
1670 } else {
1671 for (unsigned i = 0; i < inst->getNumOperands(); ++i) {
1672 lGetAddressSpace(inst->getOperand(i), done, addrSpaceVec);
1673 }
1674 }
1675
1676 if (isConstExpr) {
1677 // This is the only return point that constant expression instruction
1678 // can reach, drop all references here
1679 inst->dropAllReferences();
1680 }
1681 }
1682
/** This routine attempts to determine which address space the given value
    points into, by traversing its operand chain: External and Global
    origins take precedence, and only a pointer that traces back purely to
    an AllocaInst is reported as stack-allocated (Local).
 */
GetAddressSpace(llvm::Value * v)1687 AddressSpace GetAddressSpace(llvm::Value *v) {
1688 std::set<llvm::Value *> done;
1689 std::set<AddressSpace> addrSpaceVec;
1690 lGetAddressSpace(v, done, addrSpaceVec);
1691 if (addrSpaceVec.find(AddressSpace::External) != addrSpaceVec.end()) {
1692 return AddressSpace::External;
1693 }
1694 if (addrSpaceVec.find(AddressSpace::Global) != addrSpaceVec.end()) {
1695 return AddressSpace::Global;
1696 }
1697 return AddressSpace::Local;
1698 }
1699 #endif
1700 } // namespace ispc
1701