1 /*
2   Copyright (c) 2010-2021, Intel Corporation
3   All rights reserved.
4 
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions are
7   met:
8 
9     * Redistributions of source code must retain the above copyright
10       notice, this list of conditions and the following disclaimer.
11 
12     * Redistributions in binary form must reproduce the above copyright
13       notice, this list of conditions and the following disclaimer in the
14       documentation and/or other materials provided with the distribution.
15 
16     * Neither the name of Intel Corporation nor the names of its
17       contributors may be used to endorse or promote products derived from
18       this software without specific prior written permission.
19 
20 
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23    TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33 
34 /** @file llvmutil.cpp
35     @brief Implementations of various LLVM utility types and classes.
36 */
37 
38 #include "llvmutil.h"
39 #include "ispc.h"
40 #include "type.h"
41 
42 #include <map>
43 #include <set>
44 #include <vector>
45 
46 #include <llvm/Analysis/ValueTracking.h>
47 #include <llvm/IR/BasicBlock.h>
48 #include <llvm/IR/Instructions.h>
49 #include <llvm/IR/Module.h>
50 
51 #ifdef ISPC_GENX_ENABLED
52 #include <llvm/GenXIntrinsics/GenXIntrinsics.h>
53 #endif
54 
55 namespace ispc {
56 
57 llvm::Type *LLVMTypes::VoidType = NULL;
58 llvm::PointerType *LLVMTypes::VoidPointerType = NULL;
59 llvm::Type *LLVMTypes::PointerIntType = NULL;
60 llvm::Type *LLVMTypes::BoolType = NULL;
61 llvm::Type *LLVMTypes::BoolStorageType = NULL;
62 
63 llvm::Type *LLVMTypes::Int8Type = NULL;
64 llvm::Type *LLVMTypes::Int16Type = NULL;
65 llvm::Type *LLVMTypes::Int32Type = NULL;
66 llvm::Type *LLVMTypes::Int64Type = NULL;
67 llvm::Type *LLVMTypes::FloatType = NULL;
68 llvm::Type *LLVMTypes::DoubleType = NULL;
69 
70 llvm::Type *LLVMTypes::Int8PointerType = NULL;
71 llvm::Type *LLVMTypes::Int16PointerType = NULL;
72 llvm::Type *LLVMTypes::Int32PointerType = NULL;
73 llvm::Type *LLVMTypes::Int64PointerType = NULL;
74 llvm::Type *LLVMTypes::FloatPointerType = NULL;
75 llvm::Type *LLVMTypes::DoublePointerType = NULL;
76 
77 llvm::VectorType *LLVMTypes::MaskType = NULL;
78 llvm::VectorType *LLVMTypes::BoolVectorType = NULL;
79 llvm::VectorType *LLVMTypes::BoolVectorStorageType = NULL;
80 
81 llvm::VectorType *LLVMTypes::Int1VectorType = NULL;
82 llvm::VectorType *LLVMTypes::Int8VectorType = NULL;
83 llvm::VectorType *LLVMTypes::Int16VectorType = NULL;
84 llvm::VectorType *LLVMTypes::Int32VectorType = NULL;
85 llvm::VectorType *LLVMTypes::Int64VectorType = NULL;
86 llvm::VectorType *LLVMTypes::FloatVectorType = NULL;
87 llvm::VectorType *LLVMTypes::DoubleVectorType = NULL;
88 
89 llvm::Type *LLVMTypes::Int8VectorPointerType = NULL;
90 llvm::Type *LLVMTypes::Int16VectorPointerType = NULL;
91 llvm::Type *LLVMTypes::Int32VectorPointerType = NULL;
92 llvm::Type *LLVMTypes::Int64VectorPointerType = NULL;
93 llvm::Type *LLVMTypes::FloatVectorPointerType = NULL;
94 llvm::Type *LLVMTypes::DoubleVectorPointerType = NULL;
95 
96 llvm::VectorType *LLVMTypes::VoidPointerVectorType = NULL;
97 
98 llvm::Constant *LLVMTrue = NULL;
99 llvm::Constant *LLVMFalse = NULL;
100 llvm::Constant *LLVMTrueInStorage = NULL;
101 llvm::Constant *LLVMFalseInStorage = NULL;
102 llvm::Constant *LLVMMaskAllOn = NULL;
103 llvm::Constant *LLVMMaskAllOff = NULL;
104 
InitLLVMUtil(llvm::LLVMContext * ctx,Target & target)105 void InitLLVMUtil(llvm::LLVMContext *ctx, Target &target) {
106     LLVMTypes::VoidType = llvm::Type::getVoidTy(*ctx);
107     LLVMTypes::VoidPointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*ctx), 0);
108     LLVMTypes::PointerIntType = target.is32Bit() ? llvm::Type::getInt32Ty(*ctx) : llvm::Type::getInt64Ty(*ctx);
109 
110     LLVMTypes::BoolType = llvm::Type::getInt1Ty(*ctx);
111     LLVMTypes::Int8Type = LLVMTypes::BoolStorageType = llvm::Type::getInt8Ty(*ctx);
112     LLVMTypes::Int16Type = llvm::Type::getInt16Ty(*ctx);
113     LLVMTypes::Int32Type = llvm::Type::getInt32Ty(*ctx);
114     LLVMTypes::Int64Type = llvm::Type::getInt64Ty(*ctx);
115     LLVMTypes::FloatType = llvm::Type::getFloatTy(*ctx);
116     LLVMTypes::DoubleType = llvm::Type::getDoubleTy(*ctx);
117 
118     LLVMTypes::Int8PointerType = llvm::PointerType::get(LLVMTypes::Int8Type, 0);
119     LLVMTypes::Int16PointerType = llvm::PointerType::get(LLVMTypes::Int16Type, 0);
120     LLVMTypes::Int32PointerType = llvm::PointerType::get(LLVMTypes::Int32Type, 0);
121     LLVMTypes::Int64PointerType = llvm::PointerType::get(LLVMTypes::Int64Type, 0);
122     LLVMTypes::FloatPointerType = llvm::PointerType::get(LLVMTypes::FloatType, 0);
123     LLVMTypes::DoublePointerType = llvm::PointerType::get(LLVMTypes::DoubleType, 0);
124 
125     switch (target.getMaskBitCount()) {
126     case 1:
127         LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
128             LLVMVECTOR::get(llvm::Type::getInt1Ty(*ctx), target.getVectorWidth());
129         break;
130     case 8:
131         LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
132             LLVMVECTOR::get(llvm::Type::getInt8Ty(*ctx), target.getVectorWidth());
133         break;
134     case 16:
135         LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
136             LLVMVECTOR::get(llvm::Type::getInt16Ty(*ctx), target.getVectorWidth());
137         break;
138     case 32:
139         LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
140             LLVMVECTOR::get(llvm::Type::getInt32Ty(*ctx), target.getVectorWidth());
141         break;
142     case 64:
143         LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
144             LLVMVECTOR::get(llvm::Type::getInt64Ty(*ctx), target.getVectorWidth());
145         break;
146     default:
147         FATAL("Unhandled mask width for initializing MaskType");
148     }
149 
150     LLVMTypes::Int1VectorType = LLVMVECTOR::get(llvm::Type::getInt1Ty(*ctx), target.getVectorWidth());
151     LLVMTypes::Int8VectorType = LLVMTypes::BoolVectorStorageType =
152         LLVMVECTOR::get(LLVMTypes::Int8Type, target.getVectorWidth());
153     LLVMTypes::Int16VectorType = LLVMVECTOR::get(LLVMTypes::Int16Type, target.getVectorWidth());
154     LLVMTypes::Int32VectorType = LLVMVECTOR::get(LLVMTypes::Int32Type, target.getVectorWidth());
155     LLVMTypes::Int64VectorType = LLVMVECTOR::get(LLVMTypes::Int64Type, target.getVectorWidth());
156     LLVMTypes::FloatVectorType = LLVMVECTOR::get(LLVMTypes::FloatType, target.getVectorWidth());
157     LLVMTypes::DoubleVectorType = LLVMVECTOR::get(LLVMTypes::DoubleType, target.getVectorWidth());
158 
159     LLVMTypes::Int8VectorPointerType = llvm::PointerType::get(LLVMTypes::Int8VectorType, 0);
160     LLVMTypes::Int16VectorPointerType = llvm::PointerType::get(LLVMTypes::Int16VectorType, 0);
161     LLVMTypes::Int32VectorPointerType = llvm::PointerType::get(LLVMTypes::Int32VectorType, 0);
162     LLVMTypes::Int64VectorPointerType = llvm::PointerType::get(LLVMTypes::Int64VectorType, 0);
163     LLVMTypes::FloatVectorPointerType = llvm::PointerType::get(LLVMTypes::FloatVectorType, 0);
164     LLVMTypes::DoubleVectorPointerType = llvm::PointerType::get(LLVMTypes::DoubleVectorType, 0);
165 
166     LLVMTypes::VoidPointerVectorType = g->target->is32Bit() ? LLVMTypes::Int32VectorType : LLVMTypes::Int64VectorType;
167 
168     LLVMTrue = llvm::ConstantInt::getTrue(*ctx);
169     LLVMFalse = llvm::ConstantInt::getFalse(*ctx);
170     LLVMTrueInStorage = llvm::ConstantInt::get(LLVMTypes::Int8Type, 0xff, false /*unsigned*/);
171     LLVMFalseInStorage = llvm::ConstantInt::get(LLVMTypes::Int8Type, 0x00, false /*unsigned*/);
172 
173     std::vector<llvm::Constant *> maskOnes;
174     llvm::Constant *onMask = NULL;
175     switch (target.getMaskBitCount()) {
176     case 1:
177         onMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 1, false /*unsigned*/); // 0x1
178         break;
179     case 8:
180         onMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), -1, true /*signed*/); // 0xff
181         break;
182     case 16:
183         onMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), -1, true /*signed*/); // 0xffff
184         break;
185     case 32:
186         onMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), -1, true /*signed*/); // 0xffffffff
187         break;
188     case 64:
189         onMask = llvm::ConstantInt::get(llvm::Type::getInt64Ty(*ctx), -1, true /*signed*/); // 0xffffffffffffffffull
190         break;
191     default:
192         FATAL("Unhandled mask width for onMask");
193     }
194 
195     for (int i = 0; i < target.getVectorWidth(); ++i)
196         maskOnes.push_back(onMask);
197     LLVMMaskAllOn = llvm::ConstantVector::get(maskOnes);
198 
199     std::vector<llvm::Constant *> maskZeros;
200     llvm::Constant *offMask = NULL;
201     switch (target.getMaskBitCount()) {
202     case 1:
203         offMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 0, true /*signed*/);
204         break;
205     case 8:
206         offMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), 0, true /*signed*/);
207         break;
208     case 16:
209         offMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), 0, true /*signed*/);
210         break;
211     case 32:
212         offMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0, true /*signed*/);
213         break;
214     case 64:
215         offMask = llvm::ConstantInt::get(llvm::Type::getInt64Ty(*ctx), 0, true /*signed*/);
216         break;
217     default:
218         FATAL("Unhandled mask width for offMask");
219     }
220     for (int i = 0; i < target.getVectorWidth(); ++i)
221         maskZeros.push_back(offMask);
222     LLVMMaskAllOff = llvm::ConstantVector::get(maskZeros);
223 }
224 
LLVMInt8(int8_t ival)225 llvm::ConstantInt *LLVMInt8(int8_t ival) {
226     return llvm::ConstantInt::get(llvm::Type::getInt8Ty(*g->ctx), ival, true /*signed*/);
227 }
228 
LLVMUInt8(uint8_t ival)229 llvm::ConstantInt *LLVMUInt8(uint8_t ival) {
230     return llvm::ConstantInt::get(llvm::Type::getInt8Ty(*g->ctx), ival, false /*unsigned*/);
231 }
232 
LLVMInt16(int16_t ival)233 llvm::ConstantInt *LLVMInt16(int16_t ival) {
234     return llvm::ConstantInt::get(llvm::Type::getInt16Ty(*g->ctx), ival, true /*signed*/);
235 }
236 
LLVMUInt16(uint16_t ival)237 llvm::ConstantInt *LLVMUInt16(uint16_t ival) {
238     return llvm::ConstantInt::get(llvm::Type::getInt16Ty(*g->ctx), ival, false /*unsigned*/);
239 }
240 
LLVMInt32(int32_t ival)241 llvm::ConstantInt *LLVMInt32(int32_t ival) {
242     return llvm::ConstantInt::get(llvm::Type::getInt32Ty(*g->ctx), ival, true /*signed*/);
243 }
244 
LLVMUInt32(uint32_t ival)245 llvm::ConstantInt *LLVMUInt32(uint32_t ival) {
246     return llvm::ConstantInt::get(llvm::Type::getInt32Ty(*g->ctx), ival, false /*unsigned*/);
247 }
248 
LLVMInt64(int64_t ival)249 llvm::ConstantInt *LLVMInt64(int64_t ival) {
250     return llvm::ConstantInt::get(llvm::Type::getInt64Ty(*g->ctx), ival, true /*signed*/);
251 }
252 
LLVMUInt64(uint64_t ival)253 llvm::ConstantInt *LLVMUInt64(uint64_t ival) {
254     return llvm::ConstantInt::get(llvm::Type::getInt64Ty(*g->ctx), ival, false /*unsigned*/);
255 }
256 
LLVMFloat(float fval)257 llvm::Constant *LLVMFloat(float fval) { return llvm::ConstantFP::get(llvm::Type::getFloatTy(*g->ctx), fval); }
258 
LLVMDouble(double dval)259 llvm::Constant *LLVMDouble(double dval) { return llvm::ConstantFP::get(llvm::Type::getDoubleTy(*g->ctx), dval); }
260 
LLVMInt8Vector(int8_t ival)261 llvm::Constant *LLVMInt8Vector(int8_t ival) {
262     llvm::Constant *v = LLVMInt8(ival);
263     std::vector<llvm::Constant *> vals;
264     for (int i = 0; i < g->target->getVectorWidth(); ++i)
265         vals.push_back(v);
266     return llvm::ConstantVector::get(vals);
267 }
268 
LLVMInt8Vector(const int8_t * ivec)269 llvm::Constant *LLVMInt8Vector(const int8_t *ivec) {
270     std::vector<llvm::Constant *> vals;
271     for (int i = 0; i < g->target->getVectorWidth(); ++i)
272         vals.push_back(LLVMInt8(ivec[i]));
273     return llvm::ConstantVector::get(vals);
274 }
275 
LLVMUInt8Vector(uint8_t ival)276 llvm::Constant *LLVMUInt8Vector(uint8_t ival) {
277     llvm::Constant *v = LLVMUInt8(ival);
278     std::vector<llvm::Constant *> vals;
279     for (int i = 0; i < g->target->getVectorWidth(); ++i)
280         vals.push_back(v);
281     return llvm::ConstantVector::get(vals);
282 }
283 
LLVMUInt8Vector(const uint8_t * ivec)284 llvm::Constant *LLVMUInt8Vector(const uint8_t *ivec) {
285     std::vector<llvm::Constant *> vals;
286     for (int i = 0; i < g->target->getVectorWidth(); ++i)
287         vals.push_back(LLVMUInt8(ivec[i]));
288     return llvm::ConstantVector::get(vals);
289 }
290 
LLVMInt16Vector(int16_t ival)291 llvm::Constant *LLVMInt16Vector(int16_t ival) {
292     llvm::Constant *v = LLVMInt16(ival);
293     std::vector<llvm::Constant *> vals;
294     for (int i = 0; i < g->target->getVectorWidth(); ++i)
295         vals.push_back(v);
296     return llvm::ConstantVector::get(vals);
297 }
298 
LLVMInt16Vector(const int16_t * ivec)299 llvm::Constant *LLVMInt16Vector(const int16_t *ivec) {
300     std::vector<llvm::Constant *> vals;
301     for (int i = 0; i < g->target->getVectorWidth(); ++i)
302         vals.push_back(LLVMInt16(ivec[i]));
303     return llvm::ConstantVector::get(vals);
304 }
305 
LLVMUInt16Vector(uint16_t ival)306 llvm::Constant *LLVMUInt16Vector(uint16_t ival) {
307     llvm::Constant *v = LLVMUInt16(ival);
308     std::vector<llvm::Constant *> vals;
309     for (int i = 0; i < g->target->getVectorWidth(); ++i)
310         vals.push_back(v);
311     return llvm::ConstantVector::get(vals);
312 }
313 
LLVMUInt16Vector(const uint16_t * ivec)314 llvm::Constant *LLVMUInt16Vector(const uint16_t *ivec) {
315     std::vector<llvm::Constant *> vals;
316     for (int i = 0; i < g->target->getVectorWidth(); ++i)
317         vals.push_back(LLVMUInt16(ivec[i]));
318     return llvm::ConstantVector::get(vals);
319 }
320 
LLVMInt32Vector(int32_t ival)321 llvm::Constant *LLVMInt32Vector(int32_t ival) {
322     llvm::Constant *v = LLVMInt32(ival);
323     std::vector<llvm::Constant *> vals;
324     for (int i = 0; i < g->target->getVectorWidth(); ++i)
325         vals.push_back(v);
326     return llvm::ConstantVector::get(vals);
327 }
328 
LLVMInt32Vector(const int32_t * ivec)329 llvm::Constant *LLVMInt32Vector(const int32_t *ivec) {
330     std::vector<llvm::Constant *> vals;
331     for (int i = 0; i < g->target->getVectorWidth(); ++i)
332         vals.push_back(LLVMInt32(ivec[i]));
333     return llvm::ConstantVector::get(vals);
334 }
335 
LLVMUInt32Vector(uint32_t ival)336 llvm::Constant *LLVMUInt32Vector(uint32_t ival) {
337     llvm::Constant *v = LLVMUInt32(ival);
338     std::vector<llvm::Constant *> vals;
339     for (int i = 0; i < g->target->getVectorWidth(); ++i)
340         vals.push_back(v);
341     return llvm::ConstantVector::get(vals);
342 }
343 
LLVMUInt32Vector(const uint32_t * ivec)344 llvm::Constant *LLVMUInt32Vector(const uint32_t *ivec) {
345     std::vector<llvm::Constant *> vals;
346     for (int i = 0; i < g->target->getVectorWidth(); ++i)
347         vals.push_back(LLVMUInt32(ivec[i]));
348     return llvm::ConstantVector::get(vals);
349 }
350 
LLVMFloatVector(float fval)351 llvm::Constant *LLVMFloatVector(float fval) {
352     llvm::Constant *v = LLVMFloat(fval);
353     std::vector<llvm::Constant *> vals;
354     for (int i = 0; i < g->target->getVectorWidth(); ++i)
355         vals.push_back(v);
356     return llvm::ConstantVector::get(vals);
357 }
358 
LLVMFloatVector(const float * fvec)359 llvm::Constant *LLVMFloatVector(const float *fvec) {
360     std::vector<llvm::Constant *> vals;
361     for (int i = 0; i < g->target->getVectorWidth(); ++i)
362         vals.push_back(LLVMFloat(fvec[i]));
363     return llvm::ConstantVector::get(vals);
364 }
365 
LLVMDoubleVector(double dval)366 llvm::Constant *LLVMDoubleVector(double dval) {
367     llvm::Constant *v = LLVMDouble(dval);
368     std::vector<llvm::Constant *> vals;
369     for (int i = 0; i < g->target->getVectorWidth(); ++i)
370         vals.push_back(v);
371     return llvm::ConstantVector::get(vals);
372 }
373 
LLVMDoubleVector(const double * dvec)374 llvm::Constant *LLVMDoubleVector(const double *dvec) {
375     std::vector<llvm::Constant *> vals;
376     for (int i = 0; i < g->target->getVectorWidth(); ++i)
377         vals.push_back(LLVMDouble(dvec[i]));
378     return llvm::ConstantVector::get(vals);
379 }
380 
LLVMInt64Vector(int64_t ival)381 llvm::Constant *LLVMInt64Vector(int64_t ival) {
382     llvm::Constant *v = LLVMInt64(ival);
383     std::vector<llvm::Constant *> vals;
384     for (int i = 0; i < g->target->getVectorWidth(); ++i)
385         vals.push_back(v);
386     return llvm::ConstantVector::get(vals);
387 }
388 
LLVMInt64Vector(const int64_t * ivec)389 llvm::Constant *LLVMInt64Vector(const int64_t *ivec) {
390     std::vector<llvm::Constant *> vals;
391     for (int i = 0; i < g->target->getVectorWidth(); ++i)
392         vals.push_back(LLVMInt64(ivec[i]));
393     return llvm::ConstantVector::get(vals);
394 }
395 
LLVMUInt64Vector(uint64_t ival)396 llvm::Constant *LLVMUInt64Vector(uint64_t ival) {
397     llvm::Constant *v = LLVMUInt64(ival);
398     std::vector<llvm::Constant *> vals;
399     for (int i = 0; i < g->target->getVectorWidth(); ++i)
400         vals.push_back(v);
401     return llvm::ConstantVector::get(vals);
402 }
403 
LLVMUInt64Vector(const uint64_t * ivec)404 llvm::Constant *LLVMUInt64Vector(const uint64_t *ivec) {
405     std::vector<llvm::Constant *> vals;
406     for (int i = 0; i < g->target->getVectorWidth(); ++i)
407         vals.push_back(LLVMUInt64(ivec[i]));
408     return llvm::ConstantVector::get(vals);
409 }
410 
LLVMBoolVector(bool b)411 llvm::Constant *LLVMBoolVector(bool b) {
412     llvm::Constant *v;
413     if (LLVMTypes::BoolVectorType == LLVMTypes::Int64VectorType)
414         v = llvm::ConstantInt::get(LLVMTypes::Int64Type, b ? 0xffffffffffffffffull : 0, false /*unsigned*/);
415     else if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
416         v = llvm::ConstantInt::get(LLVMTypes::Int32Type, b ? 0xffffffff : 0, false /*unsigned*/);
417     else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
418         v = llvm::ConstantInt::get(LLVMTypes::Int16Type, b ? 0xffff : 0, false /*unsigned*/);
419     else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
420         v = llvm::ConstantInt::get(LLVMTypes::Int8Type, b ? 0xff : 0, false /*unsigned*/);
421     else {
422         Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
423         v = b ? LLVMTrue : LLVMFalse;
424     }
425 
426     std::vector<llvm::Constant *> vals;
427     for (int i = 0; i < g->target->getVectorWidth(); ++i)
428         vals.push_back(v);
429     return llvm::ConstantVector::get(vals);
430 }
431 
LLVMBoolVector(const bool * bvec)432 llvm::Constant *LLVMBoolVector(const bool *bvec) {
433     std::vector<llvm::Constant *> vals;
434     for (int i = 0; i < g->target->getVectorWidth(); ++i) {
435         llvm::Constant *v;
436         if (LLVMTypes::BoolVectorType == LLVMTypes::Int64VectorType)
437             v = llvm::ConstantInt::get(LLVMTypes::Int64Type, bvec[i] ? 0xffffffffffffffffull : 0, false /*unsigned*/);
438         else if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
439             v = llvm::ConstantInt::get(LLVMTypes::Int32Type, bvec[i] ? 0xffffffff : 0, false /*unsigned*/);
440         else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
441             v = llvm::ConstantInt::get(LLVMTypes::Int16Type, bvec[i] ? 0xffff : 0, false /*unsigned*/);
442         else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
443             v = llvm::ConstantInt::get(LLVMTypes::Int8Type, bvec[i] ? 0xff : 0, false /*unsigned*/);
444         else {
445             Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
446             v = bvec[i] ? LLVMTrue : LLVMFalse;
447         }
448 
449         vals.push_back(v);
450     }
451     return llvm::ConstantVector::get(vals);
452 }
453 
LLVMBoolVectorInStorage(bool b)454 llvm::Constant *LLVMBoolVectorInStorage(bool b) {
455     llvm::Constant *v = b ? LLVMTrueInStorage : LLVMFalseInStorage;
456     std::vector<llvm::Constant *> vals;
457     for (int i = 0; i < g->target->getVectorWidth(); ++i)
458         vals.push_back(v);
459     return llvm::ConstantVector::get(vals);
460 }
461 
LLVMBoolVectorInStorage(const bool * bvec)462 llvm::Constant *LLVMBoolVectorInStorage(const bool *bvec) {
463     std::vector<llvm::Constant *> vals;
464     for (int i = 0; i < g->target->getVectorWidth(); ++i) {
465         llvm::Constant *v = llvm::ConstantInt::get(LLVMTypes::Int8Type, bvec[i] ? 0xff : 0, false /*unsigned*/);
466         vals.push_back(v);
467     }
468     return llvm::ConstantVector::get(vals);
469 }
470 
LLVMIntAsType(int64_t val,llvm::Type * type)471 llvm::Constant *LLVMIntAsType(int64_t val, llvm::Type *type) {
472 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
473     llvm::FixedVectorType *vecType = llvm::dyn_cast<llvm::FixedVectorType>(type);
474 #else
475     llvm::VectorType *vecType = llvm::dyn_cast<llvm::VectorType>(type);
476 #endif
477 
478     if (vecType != NULL) {
479         llvm::Constant *v = llvm::ConstantInt::get(vecType->getElementType(), val, true /* signed */);
480         std::vector<llvm::Constant *> vals;
481         for (int i = 0; i < (int)vecType->getNumElements(); ++i)
482             vals.push_back(v);
483         return llvm::ConstantVector::get(vals);
484     } else
485         return llvm::ConstantInt::get(type, val, true /* signed */);
486 }
487 
LLVMUIntAsType(uint64_t val,llvm::Type * type)488 llvm::Constant *LLVMUIntAsType(uint64_t val, llvm::Type *type) {
489 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
490     llvm::FixedVectorType *vecType = llvm::dyn_cast<llvm::FixedVectorType>(type);
491 #else
492     llvm::VectorType *vecType = llvm::dyn_cast<llvm::VectorType>(type);
493 #endif
494 
495     if (vecType != NULL) {
496         llvm::Constant *v = llvm::ConstantInt::get(vecType->getElementType(), val, false /* unsigned */);
497         std::vector<llvm::Constant *> vals;
498         for (int i = 0; i < (int)vecType->getNumElements(); ++i)
499             vals.push_back(v);
500         return llvm::ConstantVector::get(vals);
501     } else
502         return llvm::ConstantInt::get(type, val, false /* unsigned */);
503 }
504 
505 /** Conservative test to see if two llvm::Values are equal.  There are
506     (potentially many) cases where the two values actually are equal but
507     this will return false.  However, if it does return true, the two
508     vectors definitely are equal.
509 */
lValuesAreEqual(llvm::Value * v0,llvm::Value * v1,std::vector<llvm::PHINode * > & seenPhi0,std::vector<llvm::PHINode * > & seenPhi1)510 static bool lValuesAreEqual(llvm::Value *v0, llvm::Value *v1, std::vector<llvm::PHINode *> &seenPhi0,
511                             std::vector<llvm::PHINode *> &seenPhi1) {
512     // Thanks to the fact that LLVM hashes and returns the same pointer for
513     // constants (of all sorts, even constant expressions), this first test
514     // actually catches a lot of cases.  LLVM's SSA form also helps a lot
515     // with this..
516     if (v0 == v1)
517         return true;
518 
519     Assert(seenPhi0.size() == seenPhi1.size());
520     for (unsigned int i = 0; i < seenPhi0.size(); ++i)
521         if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
522             return true;
523 
524     llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
525     llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
526     if (bo0 != NULL && bo1 != NULL) {
527         if (bo0->getOpcode() != bo1->getOpcode())
528             return false;
529         return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0), seenPhi0, seenPhi1) &&
530                 lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1), seenPhi0, seenPhi1));
531     }
532 
533     llvm::CastInst *cast0 = llvm::dyn_cast<llvm::CastInst>(v0);
534     llvm::CastInst *cast1 = llvm::dyn_cast<llvm::CastInst>(v1);
535     if (cast0 != NULL && cast1 != NULL) {
536         if (cast0->getOpcode() != cast1->getOpcode())
537             return false;
538         return lValuesAreEqual(cast0->getOperand(0), cast1->getOperand(0), seenPhi0, seenPhi1);
539     }
540 
541     llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
542     llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
543     if (phi0 != NULL && phi1 != NULL) {
544         if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
545             return false;
546 
547         seenPhi0.push_back(phi0);
548         seenPhi1.push_back(phi1);
549 
550         unsigned int numIncoming = phi0->getNumIncomingValues();
551         // Check all of the incoming values: if all of them are all equal,
552         // then we're good.
553         bool anyFailure = false;
554         for (unsigned int i = 0; i < numIncoming; ++i) {
555             // FIXME: should it be ok if the incoming blocks are different,
556             // where we just return faliure in this case?
557             Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
558             if (!lValuesAreEqual(phi0->getIncomingValue(i), phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
559                 anyFailure = true;
560                 break;
561             }
562         }
563 
564         seenPhi0.pop_back();
565         seenPhi1.pop_back();
566 
567         return !anyFailure;
568     }
569 
570     return false;
571 }
572 
573 /** Given an llvm::Value known to be an integer, return its value as
574     an int64_t.
575 */
lGetIntValue(llvm::Value * offset)576 static int64_t lGetIntValue(llvm::Value *offset) {
577     llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
578     Assert(intOffset && (intOffset->getBitWidth() == 32 || intOffset->getBitWidth() == 64));
579     return intOffset->getSExtValue();
580 }
581 
582 /**  Recognizes constant vector with undef operands except the first one:
583  *   <i64 4, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef>
584  */
lIsFirstElementConstVector(llvm::Value * v)585 static bool lIsFirstElementConstVector(llvm::Value *v) {
586     llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
587     // FIXME: skipping instruction and using getOperand(1) without checking for instruction type is incorrect!
588     // This yields failing tests/write-same-loc.ispc for 32 bit targets (x86).
589     // Need to understand what initial intent was here (what instruction supposed to be handled).
590     // TODO: after fixing FIXME above isGenXTarget() needs to be removed.
591     if (g->target->isGenXTarget()) {
592         if (cv == NULL && llvm::isa<llvm::Instruction>(v)) {
593             cv = llvm::dyn_cast<llvm::ConstantVector>(llvm::dyn_cast<llvm::Instruction>(v)->getOperand(1));
594         }
595     }
596     if (cv != NULL) {
597         llvm::Constant *c = llvm::dyn_cast<llvm::Constant>(cv->getOperand(0));
598         if (c == NULL) {
599             return false;
600         }
601 
602         for (int i = 1; i < (int)cv->getNumOperands(); ++i) {
603             if (!llvm::isa<llvm::UndefValue>(cv->getOperand(i))) {
604                 return false;
605             }
606         }
607         return true;
608     }
609     return false;
610 }
611 
LLVMFlattenInsertChain(llvm::Value * inst,int vectorWidth,bool compare,bool undef,bool searchFirstUndef)612 llvm::Value *LLVMFlattenInsertChain(llvm::Value *inst, int vectorWidth, bool compare, bool undef,
613                                     bool searchFirstUndef) {
614     std::vector<llvm::Value *> elements(vectorWidth, nullptr);
615 
616     // Catch a pattern of InsertElement chain.
617     if (llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(inst)) {
618         // Gather elements of vector
619         while (ie != NULL) {
620             int64_t iOffset = lGetIntValue(ie->getOperand(2));
621             Assert(iOffset >= 0 && iOffset < vectorWidth);
622 
623             // Get the scalar value from this insert
624             if (elements[iOffset] == NULL) {
625                 elements[iOffset] = ie->getOperand(1);
626             }
627 
628             // Do we have another insert?
629             llvm::Value *insertBase = ie->getOperand(0);
630             ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
631             if (ie != NULL) {
632                 continue;
633             }
634 
635             if (llvm::isa<llvm::UndefValue>(insertBase)) {
636                 break;
637             }
638 
639             if (llvm::isa<llvm::ConstantVector>(insertBase) || llvm::isa<llvm::ConstantAggregateZero>(insertBase)) {
640                 llvm::Constant *cv = llvm::dyn_cast<llvm::Constant>(insertBase);
641                 Assert(vectorWidth == (int)(cv->getNumOperands()));
642                 for (int i = 0; i < vectorWidth; i++) {
643                     if (elements[i] == NULL) {
644                         elements[i] = cv->getOperand(i);
645                     }
646                 }
647                 break;
648             } else {
649                 // Here chain ends in llvm::LoadInst or some other.
650                 // They are not equal to each other so we should return NULL if compare
651                 // and first element if we have it.
652                 Assert(compare == true || elements[0] != NULL);
653                 if (compare) {
654                     return NULL;
655                 } else {
656                     return elements[0];
657                 }
658             }
659             // TODO: Also, should we handle some other values like
660             // ConstantDataVectors.
661         }
662         if (compare == false) {
663             // We simply want first element
664             return elements[0];
665         }
666 
667         int null_number = 0;
668         int NonNull = 0;
669         for (int i = 0; i < vectorWidth; i++) {
670             if (elements[i] == NULL) {
671                 null_number++;
672             } else {
673                 NonNull = i;
674             }
675         }
676         if (null_number == vectorWidth) {
677             // All of elements are NULLs
678             return NULL;
679         }
680         if ((undef == false) && (null_number != 0)) {
681             // We don't want NULLs in chain, but we have them
682             return NULL;
683         }
684 
685         // Compare elements of vector
686         for (int i = 0; i < vectorWidth; i++) {
687             if (elements[i] == NULL) {
688                 continue;
689             }
690 
691             std::vector<llvm::PHINode *> seenPhi0;
692             std::vector<llvm::PHINode *> seenPhi1;
693             if (lValuesAreEqual(elements[NonNull], elements[i], seenPhi0, seenPhi1) == false) {
694                 return NULL;
695             }
696         }
697         return elements[NonNull];
698     }
699 
700     // Catch a pattern of broadcast implemented as InsertElement + Shuffle:
701     //   %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0
702     //   %broadcast.1 = shufflevector <4 x i32> %smear.0, <4 x i32> undef,
703     //                                              <4 x i32> zeroinitializer
704     // Or:
705     //   %gep_ptr2int_broadcast_init = insertelement <8 x i64> undef, i64 %gep_ptr2int, i32 0
706     //   %0 = add <8 x i64> %gep_ptr2int_broadcast_init,
707     //              <i64 4, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef>
708     //   %gep_offset = shufflevector <8 x i64> %0, <8 x i64> undef, <8 x i32> zeroinitializer
709     else if (llvm::ShuffleVectorInst *shuf = llvm::dyn_cast<llvm::ShuffleVectorInst>(inst)) {
710 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
711         llvm::Value *indices = shuf->getShuffleMaskForBitcode();
712 #else
713         llvm::Value *indices = shuf->getOperand(2);
714 #endif
715 
716         if (llvm::isa<llvm::ConstantAggregateZero>(indices)) {
717             llvm::Value *op = shuf->getOperand(0);
718             llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(op);
719             if (ie == NULL && searchFirstUndef) {
720                 // Trying to recognize 2nd pattern
721                 llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(op);
722                 if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop))) {
723                     if (lIsFirstElementConstVector(bop->getOperand(1))) {
724                         ie = llvm::dyn_cast<llvm::InsertElementInst>(bop->getOperand(0));
725                     } else if (llvm::isa<llvm::InsertElementInst>(bop->getOperand(1))) {
726                         // Or shuffle vector can accept insertelement itself
727                         ie = llvm::cast<llvm::InsertElementInst>(bop->getOperand(1));
728                     }
729                 }
730             }
731             if (ie != NULL && llvm::isa<llvm::UndefValue>(ie->getOperand(0))) {
732                 llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(ie->getOperand(2));
733                 if (ci->isZero()) {
734                     return ie->getOperand(1);
735                 }
736             }
737         }
738     }
739     return NULL;
740 }
741 
LLVMExtractVectorInts(llvm::Value * v,int64_t ret[],int * nElts)742 bool LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) {
743     // Make sure we do in fact have a vector of integer values here
744 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
745     llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
746 #else
747     llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
748 #endif
749     Assert(vt != NULL);
750     Assert(llvm::isa<llvm::IntegerType>(vt->getElementType()));
751 
752     *nElts = (int)vt->getNumElements();
753 
754     if (llvm::isa<llvm::ConstantAggregateZero>(v)) {
755         for (int i = 0; i < (int)vt->getNumElements(); ++i)
756             ret[i] = 0;
757         return true;
758     }
759 
760     llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
761     if (cv == NULL)
762         return false;
763 
764     for (int i = 0; i < (int)cv->getNumElements(); ++i)
765         ret[i] = cv->getElementAsInteger(i);
766     return true;
767 }
768 
// Forward declaration: lAllDivBaseEqual() below calls this function, and
// this function in turn reaches lAllDivBaseEqual() through
// lVectorShiftRightAllEqual(), so it has to be declared before its
// definition.  splatValue is optional and defaults to NULL.
static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector<llvm::PHINode *> &seenPhis,
                                  llvm::Value **splatValue = NULL);
771 
772 /** This function checks to see if the given (scalar or vector) value is an
773     exact multiple of baseValue.  It returns true if so, and false if not
774     (or if it's not able to determine if it is).  Any vector value passed
775     in is required to have the same value in all elements (so that we can
776     just check the first element to be a multiple of the given value.)
777  */
lIsExactMultiple(llvm::Value * val,int baseValue,int vectorLength,std::vector<llvm::PHINode * > & seenPhis)778 static bool lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength,
779                              std::vector<llvm::PHINode *> &seenPhis) {
780     if (llvm::isa<llvm::VectorType>(val->getType()) == false) {
781         // If we've worked down to a constant int, then the moment of truth
782         // has arrived...
783         llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(val);
784         if (ci != NULL)
785             return (ci->getZExtValue() % baseValue) == 0;
786     } else
787         Assert(LLVMVectorValuesAllEqual(val));
788 
789     if (llvm::isa<llvm::InsertElementInst>(val) || llvm::isa<llvm::ShuffleVectorInst>(val)) {
790         llvm::Value *element = LLVMFlattenInsertChain(val, g->target->getVectorWidth());
791         // We just need to check the scalar first value, since we know that
792         // all elements are equal
793         return element ? lIsExactMultiple(element, baseValue, vectorLength, seenPhis) : false;
794     }
795 
796     llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
797     if (phi != NULL) {
798         for (unsigned int i = 0; i < seenPhis.size(); ++i)
799             if (phi == seenPhis[i])
800                 return true;
801 
802         seenPhis.push_back(phi);
803         unsigned int numIncoming = phi->getNumIncomingValues();
804 
805         // Check all of the incoming values: if all of them pass, then
806         // we're good.
807         for (unsigned int i = 0; i < numIncoming; ++i) {
808             llvm::Value *incoming = phi->getIncomingValue(i);
809             bool mult = lIsExactMultiple(incoming, baseValue, vectorLength, seenPhis);
810             if (mult == false) {
811                 seenPhis.pop_back();
812                 return false;
813             }
814         }
815         seenPhis.pop_back();
816         return true;
817     }
818 
819     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
820     if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop))) {
821         llvm::Value *op0 = bop->getOperand(0);
822         llvm::Value *op1 = bop->getOperand(1);
823 
824         bool be0 = lIsExactMultiple(op0, baseValue, vectorLength, seenPhis);
825         bool be1 = lIsExactMultiple(op1, baseValue, vectorLength, seenPhis);
826         return (be0 && be1);
827     }
828     // FIXME: mul? casts? ... ?
829 
830     return false;
831 }
832 
/** Rounds the given value up to a power of two: a value that already is
    a power of two is returned unchanged, any other positive value yields
    the next larger power of two.  The computation smears the highest set
    bit of (v - 1) into every lower bit position and then adds one. */
static int lRoundUpPow2(int v) {
    int smeared = v - 1;
    // Shifting by 1, 2, 4, 8 and 16 covers all 32 bit positions.
    for (int shift = 1; shift <= 16; shift *= 2)
        smeared |= smeared >> shift;
    return smeared + 1;
}
844 
845 /** Try to determine if all of the elements of the given vector value have
846     the same value when divided by the given baseValue.  The function
847     returns true if this can be determined to be the case, and false
848     otherwise.  (This function may fail to identify some cases where it
849     does in fact have this property, but should never report a given value
850     as being a multiple if it isn't!)
851  */
lAllDivBaseEqual(llvm::Value * val,int64_t baseValue,int vectorLength,std::vector<llvm::PHINode * > & seenPhis,bool & canAdd)852 static bool lAllDivBaseEqual(llvm::Value *val, int64_t baseValue, int vectorLength,
853                              std::vector<llvm::PHINode *> &seenPhis, bool &canAdd) {
854     Assert(llvm::isa<llvm::VectorType>(val->getType()));
855     // Make sure the base value is a positive power of 2
856     Assert(baseValue > 0 && (baseValue & (baseValue - 1)) == 0);
857 
858     // The easy case
859     if (lVectorValuesAllEqual(val, vectorLength, seenPhis))
860         return true;
861 
862     int64_t vecVals[ISPC_MAX_NVEC];
863     int nElts;
864     if (llvm::isa<llvm::VectorType>(val->getType()) && LLVMExtractVectorInts(val, vecVals, &nElts)) {
865         // If we have a vector of compile-time constant integer values,
866         // then go ahead and check them directly..
867         int64_t firstDiv = vecVals[0] / baseValue;
868         for (int i = 1; i < nElts; ++i)
869             if ((vecVals[i] / baseValue) != firstDiv)
870                 return false;
871 
872         return true;
873     }
874 
875     llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
876     if (phi != NULL) {
877         for (unsigned int i = 0; i < seenPhis.size(); ++i)
878             if (phi == seenPhis[i])
879                 return true;
880 
881         seenPhis.push_back(phi);
882         unsigned int numIncoming = phi->getNumIncomingValues();
883 
884         // Check all of the incoming values: if all of them pass, then
885         // we're good.
886         for (unsigned int i = 0; i < numIncoming; ++i) {
887             llvm::Value *incoming = phi->getIncomingValue(i);
888             bool ca = canAdd;
889             bool mult = lAllDivBaseEqual(incoming, baseValue, vectorLength, seenPhis, ca);
890             if (mult == false) {
891                 seenPhis.pop_back();
892                 return false;
893             }
894         }
895         seenPhis.pop_back();
896         return true;
897     }
898 
899     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
900     if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) && canAdd == true) {
901         llvm::Value *op0 = bop->getOperand(0);
902         llvm::Value *op1 = bop->getOperand(1);
903 
904         // Otherwise we're only going to worry about the following case,
905         // which comes up often when looping over SOA data:
906         // ashr %val, <constant shift>
907         // where %val = add %smear, <0,1,2,3...>
908         // and where the maximum of the <0,...> vector in the add is less than
909         //   1<<(constant shift),
910         // and where %smear is a smear of a value that is a multiple of
911         //   baseValue.
912 
913         int64_t addConstants[ISPC_MAX_NVEC];
914         if (LLVMExtractVectorInts(op1, addConstants, &nElts) == false)
915             return false;
916         Assert(nElts == vectorLength);
917 
918         // Do all of them give the same value when divided by baseValue?
919         int64_t firstConstDiv = addConstants[0] / baseValue;
920         for (int i = 1; i < vectorLength; ++i)
921             if ((addConstants[i] / baseValue) != firstConstDiv)
922                 return false;
923 
924         if (lVectorValuesAllEqual(op0, vectorLength, seenPhis) == false)
925             return false;
926 
927         // Note that canAdd is a reference parameter; setting this ensures
928         // that we don't allow multiple adds in other parts of the chain of
929         // dependent values from here.
930         canAdd = false;
931 
932         // Now we need to figure out the required alignment (in numbers of
933         // elements of the underlying type being indexed) of the value to
934         // which these integer addConstant[] values are being added to.  We
935         // know that we have addConstant[] values that all give the same
936         // value when divided by baseValue, but we may need a less-strict
937         // alignment than baseValue depending on the actual values.
938         //
939         // As an example, consider a case where the baseValue alignment is
940         // 16, but the addConstants here are <0,1,2,3>.  In that case, the
941         // value to which addConstants is added to only needs to be a
942         // multiple of 4.  Conversely, if addConstants are <4,5,6,7>, then
943         // we need a multiple of 8 to ensure that the final added result
944         // will still have the same value for all vector elements when
945         // divided by baseValue.
946         //
947         // All that said, here we need to find the maximum value of any of
948         // the addConstants[], mod baseValue.  If we round that up to the
949         // next power of 2, we'll have a value that will be no greater than
950         // baseValue and sometimes less.
951         int maxMod = int(addConstants[0] % baseValue);
952         for (int i = 1; i < vectorLength; ++i)
953             maxMod = std::max(maxMod, int(addConstants[i] % baseValue));
954         int requiredAlignment = lRoundUpPow2(maxMod);
955 
956         std::vector<llvm::PHINode *> seenPhisEEM;
957         return lIsExactMultiple(op0, requiredAlignment, vectorLength, seenPhisEEM);
958     }
959     // TODO: could handle mul by a vector of equal constant integer values
960     // and the like here and adjust the 'baseValue' value when it evenly
961     // divides, but unclear if it's worthwhile...
962 
963     return false;
964 }
965 
966 /** Given a vector shift right of some value by some amount, try to
967     determine if all of the elements of the final result have the same
968     value (i.e. whether the high bits are all equal, disregarding the low
969     bits that are shifted out.)  Returns true if so, and false otherwise.
970  */
lVectorShiftRightAllEqual(llvm::Value * val,llvm::Value * shift,int vectorLength)971 static bool lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift, int vectorLength) {
972     // Are we shifting all elements by a compile-time constant amount?  If
973     // not, give up.
974     int64_t shiftAmount[ISPC_MAX_NVEC];
975     int nElts;
976     if (LLVMExtractVectorInts(shift, shiftAmount, &nElts) == false)
977         return false;
978     Assert(nElts == vectorLength);
979 
980     // Is it the same amount for all elements?
981     for (int i = 0; i < vectorLength; ++i)
982         if (shiftAmount[i] != shiftAmount[0])
983             return false;
984 
985     // Now see if the value divided by (1 << shift) can be determined to
986     // have the same value for all vector elements.
987     int pow2 = 1 << shiftAmount[0];
988     bool canAdd = true;
989     std::vector<llvm::PHINode *> seenPhis;
990     bool eq = lAllDivBaseEqual(val, pow2, vectorLength, seenPhis, canAdd);
991 #if 0
992     fprintf(stderr, "check all div base equal:\n");
993     LLVMDumpValue(shift);
994     LLVMDumpValue(val);
995     fprintf(stderr, "----> %s\n\n", eq ? "true" : "false");
996 #endif
997     return eq;
998 }
999 
lVectorValuesAllEqual(llvm::Value * v,int vectorLength,std::vector<llvm::PHINode * > & seenPhis,llvm::Value ** splatValue)1000 static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector<llvm::PHINode *> &seenPhis,
1001                                   llvm::Value **splatValue) {
1002     if (vectorLength == 1)
1003         return true;
1004 
1005     if (llvm::isa<llvm::ConstantAggregateZero>(v)) {
1006         if (splatValue) {
1007             llvm::ConstantAggregateZero *caz = llvm::dyn_cast<llvm::ConstantAggregateZero>(v);
1008             *splatValue = caz->getSequentialElement();
1009         }
1010         return true;
1011     }
1012 
1013     llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
1014     if (cv != NULL) {
1015         llvm::Value *splat = cv->getSplatValue();
1016         if (splat != NULL && splatValue) {
1017             *splatValue = splat;
1018         }
1019         return (splat != NULL);
1020     }
1021 
1022     llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
1023     if (cdv != NULL) {
1024         llvm::Value *splat = cdv->getSplatValue();
1025         if (splat != NULL && splatValue) {
1026             *splatValue = splat;
1027         }
1028         return (splat != NULL);
1029     }
1030 
1031     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
1032     if (bop != NULL) {
1033         // Easy case: both operands are all equal -> return true
1034         if (lVectorValuesAllEqual(bop->getOperand(0), vectorLength, seenPhis) &&
1035             lVectorValuesAllEqual(bop->getOperand(1), vectorLength, seenPhis))
1036             return true;
1037 
1038         // If it's a shift, take a special path that tries to check if the
1039         // high (surviving) bits of the values are equal.
1040         if (bop->getOpcode() == llvm::Instruction::AShr || bop->getOpcode() == llvm::Instruction::LShr)
1041             return lVectorShiftRightAllEqual(bop->getOperand(0), bop->getOperand(1), vectorLength);
1042 
1043         return false;
1044     }
1045 
1046     llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
1047     if (cast != NULL)
1048         return lVectorValuesAllEqual(cast->getOperand(0), vectorLength, seenPhis);
1049 
1050     llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
1051     if (ie != NULL) {
1052         return (LLVMFlattenInsertChain(ie, vectorLength) != NULL);
1053     }
1054 
1055     llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
1056     if (phi) {
1057         for (unsigned int i = 0; i < seenPhis.size(); ++i)
1058             if (seenPhis[i] == phi)
1059                 return true;
1060 
1061         seenPhis.push_back(phi);
1062 
1063         unsigned int numIncoming = phi->getNumIncomingValues();
1064         // Check all of the incoming values: if all of them are all equal,
1065         // then we're good.
1066         for (unsigned int i = 0; i < numIncoming; ++i) {
1067             if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength, seenPhis)) {
1068                 seenPhis.pop_back();
1069                 return false;
1070             }
1071         }
1072 
1073         seenPhis.pop_back();
1074         return true;
1075     }
1076 
1077     if (llvm::isa<llvm::UndefValue>(v))
1078         // ?
1079         return false;
1080 
1081     Assert(!llvm::isa<llvm::Constant>(v));
1082 
1083     if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) || !llvm::isa<llvm::Instruction>(v))
1084         return false;
1085 
1086     llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
1087     if (shuffle != NULL) {
1088 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1089         llvm::Value *indices = shuffle->getShuffleMaskForBitcode();
1090 #else
1091         llvm::Value *indices = shuffle->getOperand(2);
1092 #endif
1093 
1094         if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
1095             // The easy case--just a smear of the same element across the
1096             // whole vector.
1097             return true;
1098 
1099         // TODO: handle more general cases?
1100         return false;
1101     }
1102 
1103 #if 0
1104     fprintf(stderr, "all equal: ");
1105     v->dump();
1106     fprintf(stderr, "\n");
1107     llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
1108     if (inst) {
1109         inst->getParent()->dump();
1110         fprintf(stderr, "\n");
1111         fprintf(stderr, "\n");
1112     }
1113 #endif
1114 
1115     return false;
1116 }
1117 
1118 /** Tests to see if all of the elements of the vector in the 'v' parameter
1119     are equal.  This is a conservative test and may return false for arrays
1120     where the values are actually all equal.
1121 */
LLVMVectorValuesAllEqual(llvm::Value * v,llvm::Value ** splat)1122 bool LLVMVectorValuesAllEqual(llvm::Value *v, llvm::Value **splat) {
1123 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1124     llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
1125 #else
1126     llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
1127 #endif
1128     Assert(vt != NULL);
1129     int vectorLength = vt->getNumElements();
1130 
1131     std::vector<llvm::PHINode *> seenPhis;
1132     bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis, splat);
1133 
1134     Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.", v->getName().str().c_str(), equal ? "true" : "false");
1135 #ifndef ISPC_NO_DUMPS
1136     if (g->debugPrint)
1137         LLVMDumpValue(v);
1138 #endif
1139 
1140     return equal;
1141 }
1142 
1143 /** Tests to see if a binary operator has an OR which is equivalent to an ADD.*/
IsOrEquivalentToAdd(llvm::Value * op)1144 bool IsOrEquivalentToAdd(llvm::Value *op) {
1145     bool isEq = false;
1146     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(op);
1147     if (bop != NULL && bop->getOpcode() == llvm::Instruction::Or) {
1148         // Special case when A+B --> A|B transformation is triggered
1149         // We need to prove that A|B == A+B
1150         llvm::Module *module = bop->getParent()->getParent()->getParent();
1151         llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
1152         if (!haveNoCommonBitsSet(op0, op1, module->getDataLayout()) == false) {
1153             // Fallback to A+B case
1154             isEq = true;
1155         }
1156     }
1157     return isEq;
1158 }
1159 
// Forward declaration: the lCheck*ForLinear() helpers below call this
// function and are themselves called from its definition.
static bool lVectorIsLinear(llvm::Value *v, int vectorLength, int stride, std::vector<llvm::PHINode *> &seenPhis);
1161 
1162 /** Given a vector of compile-time constant integer values, test to see if
1163     they are a linear sequence of constant integers starting from an
1164     arbirary value but then having a step of value "stride" between
1165     elements.
1166  */
lVectorIsLinearConstantInts(llvm::ConstantDataVector * cv,int vectorLength,int stride)1167 static bool lVectorIsLinearConstantInts(llvm::ConstantDataVector *cv, int vectorLength, int stride) {
1168     // Flatten the vector out into the elements array
1169     llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
1170     for (int i = 0; i < (int)cv->getNumElements(); ++i)
1171         elements.push_back(cv->getElementAsConstant(i));
1172     Assert((int)elements.size() == vectorLength);
1173 
1174     llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(elements[0]);
1175     if (ci == NULL)
1176         // Not a vector of integers
1177         return false;
1178 
1179     int64_t prevVal = ci->getSExtValue();
1180 
1181     // For each element in the array, see if it is both a ConstantInt and
1182     // if the difference between it and the value of the previous element
1183     // is stride.  If not, fail.
1184     for (int i = 1; i < vectorLength; ++i) {
1185         ci = llvm::dyn_cast<llvm::ConstantInt>(elements[i]);
1186         if (ci == NULL)
1187             return false;
1188 
1189         int64_t nextVal = ci->getSExtValue();
1190         if (prevVal + stride != nextVal)
1191             return false;
1192 
1193         prevVal = nextVal;
1194     }
1195     return true;
1196 }
1197 
1198 /** Checks to see if (op0 * op1) is a linear vector where the result is a
1199     vector with values that increase by stride.
1200  */
lCheckMulForLinear(llvm::Value * op0,llvm::Value * op1,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1201 static bool lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, int stride,
1202                                std::vector<llvm::PHINode *> &seenPhis) {
1203     // Is the first operand a constant integer value splatted across all of
1204     // the lanes?
1205     llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(op0);
1206     if (cv == NULL)
1207         return false;
1208 
1209     llvm::Constant *csplat = cv->getSplatValue();
1210     if (csplat == NULL)
1211         return false;
1212 
1213     llvm::ConstantInt *splat = llvm::dyn_cast<llvm::ConstantInt>(csplat);
1214     if (splat == NULL)
1215         return false;
1216 
1217     // If the splat value doesn't evenly divide the stride we're looking
1218     // for, there's no way that we can get the linear sequence we're
1219     // looking or.
1220     int64_t splatVal = splat->getSExtValue();
1221     if (splatVal == 0 || splatVal > stride || (stride % splatVal) != 0)
1222         return false;
1223 
1224     // Check to see if the other operand is a linear vector with stride
1225     // given by stride/splatVal.
1226     return lVectorIsLinear(op1, vectorLength, (int)(stride / splatVal), seenPhis);
1227 }
1228 
1229 /** Checks to see if (op0 << op1) is a linear vector where the result is a
1230     vector with values that increase by stride.
1231  */
lCheckShlForLinear(llvm::Value * op0,llvm::Value * op1,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1232 static bool lCheckShlForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, int stride,
1233                                std::vector<llvm::PHINode *> &seenPhis) {
1234     // Is the second operand a constant integer value splatted across all of
1235     // the lanes?
1236     llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(op1);
1237     if (cv == NULL)
1238         return false;
1239 
1240     llvm::Constant *csplat = cv->getSplatValue();
1241     if (csplat == NULL)
1242         return false;
1243 
1244     llvm::ConstantInt *splat = llvm::dyn_cast<llvm::ConstantInt>(csplat);
1245     if (splat == NULL)
1246         return false;
1247 
1248     // If (1 << the splat value) doesn't evenly divide the stride we're
1249     // looking for, there's no way that we can get the linear sequence
1250     // we're looking or.
1251     int64_t equivalentMul = (1LL << splat->getSExtValue());
1252     if (equivalentMul > stride || (stride % equivalentMul) != 0)
1253         return false;
1254 
1255     // Check to see if the first operand is a linear vector with stride
1256     // given by stride/splatVal.
1257     return lVectorIsLinear(op0, vectorLength, (int)(stride / equivalentMul), seenPhis);
1258 }
1259 
1260 /** Given (op0 AND op1), try and see if we can determine if the result is a
1261     linear sequence with a step of "stride" between values.  Returns true
1262     if so and false otherwise.  This pattern comes up when accessing SOA
1263     data.
1264  */
lCheckAndForLinear(llvm::Value * op0,llvm::Value * op1,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1265 static bool lCheckAndForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, int stride,
1266                                std::vector<llvm::PHINode *> &seenPhis) {
1267     // Require op1 to be a compile-time constant
1268     int64_t maskValue[ISPC_MAX_NVEC];
1269     int nElts;
1270     if (LLVMExtractVectorInts(op1, maskValue, &nElts) == false)
1271         return false;
1272     Assert(nElts == vectorLength);
1273 
1274     // Is op1 a smear of the same value across all lanes?  Give up if not.
1275     for (int i = 1; i < vectorLength; ++i)
1276         if (maskValue[i] != maskValue[0])
1277             return false;
1278 
1279     // If the op1 value isn't power of 2 minus one, then also give up.
1280     int64_t maskPlusOne = maskValue[0] + 1;
1281     bool isPowTwo = (maskPlusOne & (maskPlusOne - 1)) == 0;
1282     if (isPowTwo == false)
1283         return false;
1284 
1285     // The case we'll covert here is op0 being a linear vector with desired
1286     // stride, and where all of the values of op0, when divided by
1287     // maskPlusOne, have the same value.
1288     if (lVectorIsLinear(op0, vectorLength, stride, seenPhis) == false)
1289         return false;
1290 
1291     bool canAdd = true;
1292     bool isMult = lAllDivBaseEqual(op0, maskPlusOne, vectorLength, seenPhis, canAdd);
1293     return isMult;
1294 }
1295 
lVectorIsLinear(llvm::Value * v,int vectorLength,int stride,std::vector<llvm::PHINode * > & seenPhis)1296 static bool lVectorIsLinear(llvm::Value *v, int vectorLength, int stride, std::vector<llvm::PHINode *> &seenPhis) {
1297     // First try the easy case: if the values are all just constant
1298     // integers and have the expected stride between them, then we're done.
1299     llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
1300     if (cv != NULL)
1301         return lVectorIsLinearConstantInts(cv, vectorLength, stride);
1302 
1303     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
1304     if (bop != NULL) {
1305         // FIXME: is it right to pass the seenPhis to the all equal check as well??
1306         llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
1307         if ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) {
1308             // There are two cases to check if we have an add:
1309             //
1310             // programIndex + unif -> ascending linear seqeuence
1311             // unif + programIndex -> ascending linear sequence
1312             bool l0 = lVectorIsLinear(op0, vectorLength, stride, seenPhis);
1313             bool e1 = lVectorValuesAllEqual(op1, vectorLength, seenPhis);
1314             if (l0 && e1)
1315                 return true;
1316 
1317             bool e0 = lVectorValuesAllEqual(op0, vectorLength, seenPhis);
1318             bool l1 = lVectorIsLinear(op1, vectorLength, stride, seenPhis);
1319             return (e0 && l1);
1320         } else if (bop->getOpcode() == llvm::Instruction::Sub)
1321             // For subtraction, we only match:
1322             // programIndex - unif -> ascending linear seqeuence
1323             return (lVectorIsLinear(bop->getOperand(0), vectorLength, stride, seenPhis) &&
1324                     lVectorValuesAllEqual(bop->getOperand(1), vectorLength, seenPhis));
1325         else if (bop->getOpcode() == llvm::Instruction::Mul) {
1326             // Multiplies are a bit trickier, so are handled in a separate
1327             // function.
1328             bool m0 = lCheckMulForLinear(op0, op1, vectorLength, stride, seenPhis);
1329             if (m0)
1330                 return true;
1331             bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
1332             return m1;
1333         } else if (bop->getOpcode() == llvm::Instruction::Shl) {
1334             // Sometimes multiplies come in as shift lefts (especially in
1335             // LLVM 3.4+).
1336             bool linear = lCheckShlForLinear(op0, op1, vectorLength, stride, seenPhis);
1337             return linear;
1338         } else if (bop->getOpcode() == llvm::Instruction::And) {
1339             // Special case for some AND-related patterns that come up when
1340             // looping over SOA data
1341             bool linear = lCheckAndForLinear(op0, op1, vectorLength, stride, seenPhis);
1342             return linear;
1343         } else
1344             return false;
1345     }
1346 
1347     llvm::CastInst *ci = llvm::dyn_cast<llvm::CastInst>(v);
1348     if (ci != NULL)
1349         return lVectorIsLinear(ci->getOperand(0), vectorLength, stride, seenPhis);
1350 
1351     if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v))
1352         return false;
1353 
1354     llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
1355     if (phi != NULL) {
1356         for (unsigned int i = 0; i < seenPhis.size(); ++i)
1357             if (seenPhis[i] == phi)
1358                 return true;
1359 
1360         seenPhis.push_back(phi);
1361 
1362         unsigned int numIncoming = phi->getNumIncomingValues();
1363         // Check all of the incoming values: if all of them are all equal,
1364         // then we're good.
1365         for (unsigned int i = 0; i < numIncoming; ++i) {
1366             if (!lVectorIsLinear(phi->getIncomingValue(i), vectorLength, stride, seenPhis)) {
1367                 seenPhis.pop_back();
1368                 return false;
1369             }
1370         }
1371 
1372         seenPhis.pop_back();
1373         return true;
1374     }
1375 
1376     // TODO: is any reason to worry about these?
1377     if (llvm::isa<llvm::InsertElementInst>(v))
1378         return false;
1379 
1380     // TODO: we could also handle shuffles, but we haven't yet seen any
1381     // cases where doing so would detect cases where actually have a linear
1382     // vector.
1383     llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
1384     if (shuffle != NULL)
1385         return false;
1386 
1387 #if 0
1388     fprintf(stderr, "linear check: ");
1389     v->dump();
1390     fprintf(stderr, "\n");
1391     llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
1392     if (inst) {
1393         inst->getParent()->dump();
1394         fprintf(stderr, "\n");
1395         fprintf(stderr, "\n");
1396     }
1397 #endif
1398 
1399     return false;
1400 }
1401 
1402 /** Given vector of integer-typed values, see if the elements of the array
1403     have a step of 'stride' between their values.  This function tries to
1404     handle as many possibilities as possible, including things like all
1405     elements equal to some non-constant value plus an integer offset, etc.
1406 */
LLVMVectorIsLinear(llvm::Value * v,int stride)1407 bool LLVMVectorIsLinear(llvm::Value *v, int stride) {
1408 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1409     llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
1410 #else
1411     llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
1412 #endif
1413     Assert(vt != NULL);
1414     int vectorLength = vt->getNumElements();
1415 
1416     std::vector<llvm::PHINode *> seenPhis;
1417     bool linear = lVectorIsLinear(v, vectorLength, stride, seenPhis);
1418     Debug(SourcePos(), "LLVMVectorIsLinear(%s) -> %s.", v->getName().str().c_str(), linear ? "true" : "false");
1419 #ifndef ISPC_NO_DUMPS
1420     if (g->debugPrint)
1421         LLVMDumpValue(v);
1422 #endif
1423 
1424     return linear;
1425 }
1426 
1427 #ifndef ISPC_NO_DUMPS
lDumpValue(llvm::Value * v,std::set<llvm::Value * > & done)1428 static void lDumpValue(llvm::Value *v, std::set<llvm::Value *> &done) {
1429     if (done.find(v) != done.end())
1430         return;
1431 
1432     llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
1433     if (done.size() > 0 && inst == NULL)
1434         return;
1435 
1436     fprintf(stderr, "  ");
1437     //v->dump();
1438     done.insert(v);
1439 
1440     if (inst == NULL)
1441         return;
1442 
1443     for (unsigned i = 0; i < inst->getNumOperands(); ++i)
1444         lDumpValue(inst->getOperand(i), done);
1445 }
1446 
LLVMDumpValue(llvm::Value * v)1447 void LLVMDumpValue(llvm::Value *v) {
1448     std::set<llvm::Value *> done;
1449     lDumpValue(v, done);
1450     fprintf(stderr, "----\n");
1451 }
1452 #endif
1453 
lExtractFirstVectorElement(llvm::Value * v,std::map<llvm::PHINode *,llvm::PHINode * > & phiMap)1454 static llvm::Value *lExtractFirstVectorElement(llvm::Value *v, std::map<llvm::PHINode *, llvm::PHINode *> &phiMap) {
1455 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1456     llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v->getType());
1457 #else
1458     llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v->getType());
1459 #endif
1460     Assert(vt != NULL);
1461 
1462     // First, handle various constant types; do the extraction manually, as
1463     // appropriate.
1464     if (llvm::isa<llvm::ConstantAggregateZero>(v) == true) {
1465         return llvm::Constant::getNullValue(vt->getElementType());
1466     }
1467     if (llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v)) {
1468         return cv->getOperand(0);
1469     }
1470     if (llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v))
1471         return cdv->getElementAsConstant(0);
1472 
1473     // Otherwise, all that we should have at this point is an instruction
1474     // of some sort
1475     Assert(llvm::isa<llvm::Constant>(v) == false);
1476     Assert(llvm::isa<llvm::Instruction>(v) == true);
1477 
1478     std::string newName = v->getName().str() + std::string(".elt0");
1479 
1480     // Rewrite regular binary operators and casts to the scalarized
1481     // equivalent.
1482     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
1483     if (bop != NULL) {
1484         llvm::Value *v0 = lExtractFirstVectorElement(bop->getOperand(0), phiMap);
1485         llvm::Value *v1 = lExtractFirstVectorElement(bop->getOperand(1), phiMap);
1486         Assert(v0 != NULL);
1487         Assert(v1 != NULL);
1488         // Note that the new binary operator is inserted immediately before
1489         // the previous vector one
1490         return llvm::BinaryOperator::Create(bop->getOpcode(), v0, v1, newName, bop);
1491     }
1492 
1493     llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
1494     if (cast != NULL) {
1495         llvm::Value *v = lExtractFirstVectorElement(cast->getOperand(0), phiMap);
1496         // Similarly, the equivalent scalar cast instruction goes right
1497         // before the vector cast
1498         return llvm::CastInst::Create(cast->getOpcode(), v, vt->getElementType(), newName, cast);
1499     }
1500 
1501     llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
1502     if (phi != NULL) {
1503         // For PHI notes, recursively scalarize them.
1504         if (phiMap.find(phi) != phiMap.end())
1505             return phiMap[phi];
1506 
1507         // We need to create the new scalar PHI node immediately, though,
1508         // and put it in the map<>, so that if we come back to this node
1509         // via a recursive lExtractFirstVectorElement() call, then we can
1510         // return the pointer and not get stuck in an infinite loop.
1511         //
1512         // The insertion point for the new phi node also has to be the
1513         // start of the bblock of the original phi node.
1514 
1515         llvm::Instruction *phiInsertPos = &*(phi->getParent()->begin());
1516         llvm::PHINode *scalarPhi =
1517             llvm::PHINode::Create(vt->getElementType(), phi->getNumIncomingValues(), newName, phiInsertPos);
1518         phiMap[phi] = scalarPhi;
1519 
1520         for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) {
1521             llvm::Value *v = lExtractFirstVectorElement(phi->getIncomingValue(i), phiMap);
1522             scalarPhi->addIncoming(v, phi->getIncomingBlock(i));
1523         }
1524 
1525         return scalarPhi;
1526     }
1527 
1528     // We should consider "shuffle" case and "insertElement" case separately.
1529     // For example we can have shuffle(mul, undef, zero) but function
1530     // "LLVMFlattenInsertChain" can handle only case shuffle(insertElement, undef, zero).
1531     // Also if we have insertElement under shuffle we will handle it the next call of
1532     // "lExtractFirstVectorElement" function.
1533     if (llvm::isa<llvm::ShuffleVectorInst>(v)) {
1534         llvm::ShuffleVectorInst *shuf = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
1535 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1536         llvm::Value *indices = shuf->getShuffleMaskForBitcode();
1537 #else
1538         llvm::Value *indices = shuf->getOperand(2);
1539 #endif
1540         if (llvm::isa<llvm::ConstantAggregateZero>(indices)) {
1541             return lExtractFirstVectorElement(shuf->getOperand(0), phiMap);
1542         }
1543     }
1544 
1545     // If we have a chain of insertelement instructions, then we can just
1546     // flatten them out and grab the value for the first one.
1547     if (llvm::isa<llvm::InsertElementInst>(v)) {
1548         return LLVMFlattenInsertChain(v, vt->getNumElements(), false);
1549     }
1550 
1551     // Worst case, for everything else, just do a regular extract element
1552     // instruction, which we insert immediately after the instruction we
1553     // have here.
1554     llvm::Instruction *insertAfter = llvm::dyn_cast<llvm::Instruction>(v);
1555     Assert(insertAfter != NULL);
1556     llvm::Instruction *ee = llvm::ExtractElementInst::Create(v, LLVMInt32(0), "first_elt", (llvm::Instruction *)NULL);
1557     ee->insertAfter(insertAfter);
1558     return ee;
1559 }
1560 
LLVMExtractFirstVectorElement(llvm::Value * v)1561 llvm::Value *LLVMExtractFirstVectorElement(llvm::Value *v) {
1562     std::map<llvm::PHINode *, llvm::PHINode *> phiMap;
1563     llvm::Value *ret = lExtractFirstVectorElement(v, phiMap);
1564     return ret;
1565 }
1566 
1567 /** Given two vectors of the same type, concatenate them into a vector that
1568     has twice as many elements, where the first half has the elements from
1569     the first vector and the second half has the elements from the second
1570     vector.
1571  */
LLVMConcatVectors(llvm::Value * v1,llvm::Value * v2,llvm::Instruction * insertBefore)1572 llvm::Value *LLVMConcatVectors(llvm::Value *v1, llvm::Value *v2, llvm::Instruction *insertBefore) {
1573     Assert(v1->getType() == v2->getType());
1574 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1575     llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(v1->getType());
1576 #else
1577     llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(v1->getType());
1578 #endif
1579     Assert(vt != NULL);
1580 
1581     int32_t identity[ISPC_MAX_NVEC];
1582     int resultSize = 2 * vt->getNumElements();
1583     Assert(resultSize <= ISPC_MAX_NVEC);
1584     for (int i = 0; i < resultSize; ++i)
1585         identity[i] = i;
1586 
1587     return LLVMShuffleVectors(v1, v2, identity, resultSize, insertBefore);
1588 }
1589 
1590 /** Shuffle two vectors together with a ShuffleVectorInst, returning a
1591     vector with shufSize elements, where the shuf[] array offsets are used
1592     to determine which element from the two given vectors is used for each
1593     result element. */
LLVMShuffleVectors(llvm::Value * v1,llvm::Value * v2,int32_t shuf[],int shufSize,llvm::Instruction * insertBefore)1594 llvm::Value *LLVMShuffleVectors(llvm::Value *v1, llvm::Value *v2, int32_t shuf[], int shufSize,
1595                                 llvm::Instruction *insertBefore) {
1596     std::vector<llvm::Constant *> shufVec;
1597     for (int i = 0; i < shufSize; ++i) {
1598         if (shuf[i] == -1)
1599             shufVec.push_back(llvm::UndefValue::get(LLVMTypes::Int32Type));
1600         else
1601             shufVec.push_back(LLVMInt32(shuf[i]));
1602     }
1603 
1604     llvm::ArrayRef<llvm::Constant *> aref(&shufVec[0], &shufVec[shufSize]);
1605     llvm::Value *vec = llvm::ConstantVector::get(aref);
1606 
1607     return new llvm::ShuffleVectorInst(v1, v2, vec, "shuffle", insertBefore);
1608 }
1609 
1610 #ifdef ISPC_GENX_ENABLED
lIsSVMLoad(llvm::Instruction * inst)1611 bool lIsSVMLoad(llvm::Instruction *inst) {
1612     Assert(inst);
1613 
1614     switch (llvm::GenXIntrinsic::getGenXIntrinsicID(inst)) {
1615     case llvm::GenXIntrinsic::genx_svm_block_ld:
1616     case llvm::GenXIntrinsic::genx_svm_block_ld_unaligned:
1617     case llvm::GenXIntrinsic::genx_svm_gather:
1618         return true;
1619     default:
1620         return false;
1621     }
1622 }
1623 
lGetAddressSpace(llvm::Value * v,std::set<llvm::Value * > & done,std::set<AddressSpace> & addrSpaceVec)1624 void lGetAddressSpace(llvm::Value *v, std::set<llvm::Value *> &done, std::set<AddressSpace> &addrSpaceVec) {
1625     if (done.find(v) != done.end()) {
1626         if (llvm::isa<llvm::PointerType>(v->getType()))
1627             addrSpaceVec.insert(AddressSpace::External);
1628         return;
1629     }
1630     // Found global value
1631     if (llvm::isa<llvm::GlobalValue>(v)) {
1632         addrSpaceVec.insert(AddressSpace::Global);
1633         return;
1634     }
1635 
1636     llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
1637     bool isConstExpr = false;
1638     if (inst == NULL) {
1639         // Case when GEP is constant expression
1640         if (llvm::ConstantExpr *constExpr = llvm::dyn_cast<llvm::ConstantExpr>(v)) {
1641             // This instruction isn't inserted anywhere, so delete it when done
1642             inst = constExpr->getAsInstruction();
1643             isConstExpr = true;
1644         }
1645     }
1646 
1647     if (done.size() > 0 && inst == NULL) {
1648         // Found external pointer like "float* %aFOO"
1649         if (llvm::isa<llvm::PointerType>(v->getType()))
1650             addrSpaceVec.insert(AddressSpace::External);
1651         return;
1652     }
1653 
1654     done.insert(v);
1655 
1656     // Found value allocated on stack like "%val = alloca [16 x float]"
1657     if (llvm::isa<llvm::AllocaInst>(v)) {
1658         addrSpaceVec.insert(AddressSpace::Local);
1659         return;
1660     }
1661 
1662     if (inst == NULL || llvm::isa<llvm::CallInst>(v)) {
1663         if (llvm::isa<llvm::PointerType>(v->getType()) || (inst && lIsSVMLoad(inst)))
1664             addrSpaceVec.insert(AddressSpace::External);
1665         return;
1666     }
1667     // For GEP we check only pointer operand, for the rest check all
1668     if (llvm::isa<llvm::GetElementPtrInst>(inst)) {
1669         lGetAddressSpace(inst->getOperand(0), done, addrSpaceVec);
1670     } else {
1671         for (unsigned i = 0; i < inst->getNumOperands(); ++i) {
1672             lGetAddressSpace(inst->getOperand(i), done, addrSpaceVec);
1673         }
1674     }
1675 
1676     if (isConstExpr) {
1677         // This is the only return point that constant expression instruction
1678         // can reach, drop all references here
1679         inst->dropAllReferences();
1680     }
1681 }
1682 
1683 /** This routine attempts to determine if the given value is pointing to
1684     stack-allocated memory. The basic strategy is to traverse through the
1685     operands and see if the pointer originally comes from an AllocaInst.
1686 */
GetAddressSpace(llvm::Value * v)1687 AddressSpace GetAddressSpace(llvm::Value *v) {
1688     std::set<llvm::Value *> done;
1689     std::set<AddressSpace> addrSpaceVec;
1690     lGetAddressSpace(v, done, addrSpaceVec);
1691     if (addrSpaceVec.find(AddressSpace::External) != addrSpaceVec.end()) {
1692         return AddressSpace::External;
1693     }
1694     if (addrSpaceVec.find(AddressSpace::Global) != addrSpaceVec.end()) {
1695         return AddressSpace::Global;
1696     }
1697     return AddressSpace::Local;
1698 }
1699 #endif
1700 } // namespace ispc
1701