1 /*
2 Copyright (c) 2010-2021, Intel Corporation
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are
7 met:
8
9 * Redistributions of source code must retain the above copyright
10 notice, this list of conditions and the following disclaimer.
11
12 * Redistributions in binary form must reproduce the above copyright
13 notice, this list of conditions and the following disclaimer in the
14 documentation and/or other materials provided with the distribution.
15
16 * Neither the name of Intel Corporation nor the names of its
17 contributors may be used to endorse or promote products derived from
18 this software without specific prior written permission.
19
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22 IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23 TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24 PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 /** @file expr.cpp
35 @brief Implementations of expression classes
36 */
37
38 #include "expr.h"
39 #include "ast.h"
40 #include "ctx.h"
41 #include "llvmutil.h"
42 #include "module.h"
43 #include "sym.h"
44 #include "type.h"
45 #include "util.h"
46
47 #ifndef _MSC_VER
48 #include <inttypes.h>
49 #endif
50 #ifndef PRId64
51 #define PRId64 "lld"
52 #endif
53 #ifndef PRIu64
54 #define PRIu64 "llu"
55 #endif
56
57 #include <list>
58 #include <set>
59 #include <stdio.h>
60
61 #include <llvm/ExecutionEngine/GenericValue.h>
62 #include <llvm/IR/CallingConv.h>
63 #include <llvm/IR/DerivedTypes.h>
64 #include <llvm/IR/Function.h>
65 #include <llvm/IR/InstIterator.h>
66 #include <llvm/IR/Instructions.h>
67 #include <llvm/IR/LLVMContext.h>
68 #include <llvm/IR/Module.h>
69 #include <llvm/IR/Type.h>
70
71 using namespace ispc;
72
73 /////////////////////////////////////////////////////////////////////////////////////
74 // Expr
75
GetLValue(FunctionEmitContext * ctx) const76 llvm::Value *Expr::GetLValue(FunctionEmitContext *ctx) const {
77 // Expressions that can't provide an lvalue can just return NULL
78 return NULL;
79 }
80
GetLValueType() const81 const Type *Expr::GetLValueType() const {
82 // This also only needs to be overrided by Exprs that implement the
83 // GetLValue() method.
84 return NULL;
85 }
86
GetStorageConstant(const Type * type) const87 std::pair<llvm::Constant *, bool> Expr::GetStorageConstant(const Type *type) const { return GetConstant(type); }
GetConstant(const Type * type) const88 std::pair<llvm::Constant *, bool> Expr::GetConstant(const Type *type) const {
89 // The default is failure; just return NULL
90 return std::pair<llvm::Constant *, bool>(NULL, false);
91 }
92
GetBaseSymbol() const93 Symbol *Expr::GetBaseSymbol() const {
94 // Not all expressions can do this, so provide a generally-useful
95 // default implementation.
96 return NULL;
97 }
98
HasAmbiguousVariability(std::vector<const Expr * > & warn) const99 bool Expr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const { return false; }
100
#if 0
/** Warns when converting from 'fromAtomicType' to 'toAtomicType' may lose
    precision.  No warning is issued for conversions to bool, nor for
    conversions from an unsigned integer type to the signed integer type of
    the same size.

    This helper is currently compiled out.
 */
static void
lMaybeIssuePrecisionWarning(const AtomicType *toAtomicType,
                            const AtomicType *fromAtomicType,
                            SourcePos pos, const char *errorMsgBase) {
    const AtomicType::BasicType toBasic = toAtomicType->basicType;
    const AtomicType::BasicType fromBasic = fromAtomicType->basicType;

    switch (toBasic) {
    case AtomicType::TYPE_BOOL:
    case AtomicType::TYPE_INT8:
    case AtomicType::TYPE_UINT8:
    case AtomicType::TYPE_INT16:
    case AtomicType::TYPE_UINT16:
    case AtomicType::TYPE_INT32:
    case AtomicType::TYPE_UINT32:
    case AtomicType::TYPE_FLOAT:
    case AtomicType::TYPE_INT64:
    case AtomicType::TYPE_UINT64:
    case AtomicType::TYPE_DOUBLE: {
        // The destination ranks below the source in the BasicType ordering.
        const bool narrowing = (int)toBasic < (int)fromBasic;
        const bool toBool = (toBasic == AtomicType::TYPE_BOOL);
        // unsigned -> signed of the same width is tolerated silently.
        const bool sameSizeSignChange =
            (toBasic == AtomicType::TYPE_INT8 && fromBasic == AtomicType::TYPE_UINT8) ||
            (toBasic == AtomicType::TYPE_INT16 && fromBasic == AtomicType::TYPE_UINT16) ||
            (toBasic == AtomicType::TYPE_INT32 && fromBasic == AtomicType::TYPE_UINT32) ||
            (toBasic == AtomicType::TYPE_INT64 && fromBasic == AtomicType::TYPE_UINT64);

        if (narrowing && !toBool && !sameSizeSignChange)
            Warning(pos, "Conversion from type \"%s\" to type \"%s\" for %s"
                    " may lose information.",
                    fromAtomicType->GetString().c_str(), toAtomicType->GetString().c_str(),
                    errorMsgBase);
        break;
    }
    default:
        FATAL("logic error in lMaybeIssuePrecisionWarning()");
    }
}
#endif
142
143 ///////////////////////////////////////////////////////////////////////////
144
lArrayToPointer(Expr * expr)145 static Expr *lArrayToPointer(Expr *expr) {
146 Assert(expr != NULL);
147 AssertPos(expr->pos, CastType<ArrayType>(expr->GetType()));
148
149 Expr *zero = new ConstExpr(AtomicType::UniformInt32, 0, expr->pos);
150 Expr *index = new IndexExpr(expr, zero, expr->pos);
151 Expr *addr = new AddressOfExpr(index, expr->pos);
152 addr = TypeCheck(addr);
153 Assert(addr != NULL);
154 addr = Optimize(addr);
155 Assert(addr != NULL);
156 return addr;
157 }
158
lIsAllIntZeros(Expr * expr)159 static bool lIsAllIntZeros(Expr *expr) {
160 const Type *type = expr->GetType();
161 if (type == NULL || type->IsIntType() == false)
162 return false;
163
164 ConstExpr *ce = llvm::dyn_cast<ConstExpr>(expr);
165 if (ce == NULL)
166 return false;
167
168 uint64_t vals[ISPC_MAX_NVEC];
169 int count = ce->GetValues(vals);
170 if (count == 1)
171 return (vals[0] == 0);
172 else {
173 for (int i = 0; i < count; ++i)
174 if (vals[i] != 0)
175 return false;
176 }
177 return true;
178 }
179
lDoTypeConv(const Type * fromType,const Type * toType,Expr ** expr,bool failureOk,const char * errorMsgBase,SourcePos pos)180 static bool lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, bool failureOk, const char *errorMsgBase,
181 SourcePos pos) {
182 /* This function is way too long and complex. Is type conversion stuff
183 always this messy, or can this be cleaned up somehow? */
184 AssertPos(pos, failureOk || errorMsgBase != NULL);
185
186 if (toType == NULL || fromType == NULL)
187 return false;
188
189 // The types are equal; there's nothing to do
190 if (Type::Equal(toType, fromType))
191 return true;
192
193 if (fromType->IsVoidType()) {
194 if (!failureOk)
195 Error(pos, "Can't convert from \"void\" to \"%s\" for %s.", toType->GetString().c_str(), errorMsgBase);
196 return false;
197 }
198
199 if (toType->IsVoidType()) {
200 if (!failureOk)
201 Error(pos, "Can't convert type \"%s\" to \"void\" for %s.", fromType->GetString().c_str(), errorMsgBase);
202 return false;
203 }
204
205 if (CastType<FunctionType>(fromType)) {
206 if (CastType<PointerType>(toType) != NULL) {
207 // Convert function type to pointer to function type
208 if (expr != NULL) {
209 Expr *aoe = new AddressOfExpr(*expr, (*expr)->pos);
210 if (lDoTypeConv(aoe->GetType(), toType, &aoe, failureOk, errorMsgBase, pos)) {
211 *expr = aoe;
212 return true;
213 }
214 } else
215 return lDoTypeConv(PointerType::GetUniform(fromType), toType, NULL, failureOk, errorMsgBase, pos);
216 } else {
217 if (!failureOk)
218 Error(pos, "Can't convert function type \"%s\" to \"%s\" for %s.", fromType->GetString().c_str(),
219 toType->GetString().c_str(), errorMsgBase);
220 return false;
221 }
222 }
223 if (CastType<FunctionType>(toType)) {
224 if (!failureOk)
225 Error(pos,
226 "Can't convert from type \"%s\" to function type \"%s\" "
227 "for %s.",
228 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
229 return false;
230 }
231
232 if ((toType->GetSOAWidth() > 0 || fromType->GetSOAWidth() > 0) &&
233 Type::Equal(toType->GetAsUniformType(), fromType->GetAsUniformType()) &&
234 toType->GetSOAWidth() != fromType->GetSOAWidth()) {
235 if (!failureOk)
236 Error(pos,
237 "Can't convert between types \"%s\" and \"%s\" with "
238 "different SOA widths for %s.",
239 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
240 return false;
241 }
242
243 const ArrayType *toArrayType = CastType<ArrayType>(toType);
244 const ArrayType *fromArrayType = CastType<ArrayType>(fromType);
245 const VectorType *toVectorType = CastType<VectorType>(toType);
246 const VectorType *fromVectorType = CastType<VectorType>(fromType);
247 const StructType *toStructType = CastType<StructType>(toType);
248 const StructType *fromStructType = CastType<StructType>(fromType);
249 const EnumType *toEnumType = CastType<EnumType>(toType);
250 const EnumType *fromEnumType = CastType<EnumType>(fromType);
251 const AtomicType *toAtomicType = CastType<AtomicType>(toType);
252 const AtomicType *fromAtomicType = CastType<AtomicType>(fromType);
253 const PointerType *fromPointerType = CastType<PointerType>(fromType);
254 const PointerType *toPointerType = CastType<PointerType>(toType);
255
256 // Do this early, since for the case of a conversion like
257 // "float foo[10]" -> "float * uniform foo", we have what's seemingly
258 // a varying to uniform conversion (but not really)
259 if (fromArrayType != NULL && toPointerType != NULL) {
260 // can convert any array to a void pointer (both uniform and
261 // varying).
262 if (PointerType::IsVoidPointer(toPointerType))
263 goto typecast_ok;
264
265 // array to pointer to array element type
266 const Type *eltType = fromArrayType->GetElementType();
267 if (toPointerType->GetBaseType()->IsConstType())
268 eltType = eltType->GetAsConstType();
269
270 PointerType pt(eltType, toPointerType->GetVariability(), toPointerType->IsConstType());
271 if (Type::Equal(toPointerType, &pt))
272 goto typecast_ok;
273 else {
274 if (!failureOk)
275 Error(pos,
276 "Can't convert from incompatible array type \"%s\" "
277 "to pointer type \"%s\" for %s.",
278 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
279 return false;
280 }
281 }
282
283 if (toType->IsUniformType() && fromType->IsVaryingType()) {
284 if (!failureOk)
285 Error(pos, "Can't convert from type \"%s\" to type \"%s\" for %s.", fromType->GetString().c_str(),
286 toType->GetString().c_str(), errorMsgBase);
287 return false;
288 }
289
290 if (fromPointerType != NULL) {
291 if (CastType<AtomicType>(toType) != NULL && toType->IsBoolType())
292 // Allow implicit conversion of pointers to bools
293 goto typecast_ok;
294
295 if (toArrayType != NULL && Type::Equal(fromType->GetBaseType(), toArrayType->GetElementType())) {
296 // Can convert pointers to arrays of the same type
297 goto typecast_ok;
298 }
299 if (toPointerType == NULL) {
300 if (!failureOk)
301 Error(pos,
302 "Can't convert between from pointer type "
303 "\"%s\" to non-pointer type \"%s\" for %s.",
304 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
305 return false;
306 } else if (fromPointerType->IsSlice() == true && toPointerType->IsSlice() == false) {
307 if (!failureOk)
308 Error(pos,
309 "Can't convert from pointer to SOA type "
310 "\"%s\" to pointer to non-SOA type \"%s\" for %s.",
311 fromPointerType->GetAsNonSlice()->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
312 return false;
313 } else if (PointerType::IsVoidPointer(toPointerType)) {
314 if (fromPointerType->GetBaseType()->IsConstType() && !(toPointerType->GetBaseType()->IsConstType())) {
315 if (!failureOk)
316 Error(pos, "Can't convert pointer to const \"%s\" to void pointer.",
317 fromPointerType->GetString().c_str());
318 return false;
319 }
320 // any pointer type can be converted to a void *
321 // ...almost. #731
322 goto typecast_ok;
323 } else if (PointerType::IsVoidPointer(fromPointerType) && expr != NULL &&
324 llvm::dyn_cast<NullPointerExpr>(*expr) != NULL) {
325 // and a NULL convert to any other pointer type
326 goto typecast_ok;
327 } else if (!Type::Equal(fromPointerType->GetBaseType(), toPointerType->GetBaseType()) &&
328 !Type::Equal(fromPointerType->GetBaseType()->GetAsConstType(), toPointerType->GetBaseType())) {
329 // for const * -> * conversion, print warning.
330 if (Type::EqualIgnoringConst(fromPointerType->GetBaseType(), toPointerType->GetBaseType())) {
331 if (!Type::Equal(fromPointerType->GetBaseType()->GetAsConstType(), toPointerType->GetBaseType())) {
332 Warning(pos,
333 "Converting from const pointer type \"%s\" to "
334 "pointer type \"%s\" for %s discards const qualifier.",
335 fromPointerType->GetString().c_str(), toPointerType->GetString().c_str(), errorMsgBase);
336 }
337 } else {
338 if (!failureOk) {
339 Error(pos,
340 "Can't convert from pointer type \"%s\" to "
341 "incompatible pointer type \"%s\" for %s.",
342 fromPointerType->GetString().c_str(), toPointerType->GetString().c_str(), errorMsgBase);
343 }
344 return false;
345 }
346 }
347
348 if (toType->IsVaryingType() && fromType->IsUniformType())
349 goto typecast_ok;
350
351 if (toPointerType->IsSlice() == true && fromPointerType->IsSlice() == false)
352 goto typecast_ok;
353
354 // Otherwise there's nothing to do
355 return true;
356 }
357
358 if (toPointerType != NULL && fromAtomicType != NULL && fromAtomicType->IsIntType() && expr != NULL &&
359 lIsAllIntZeros(*expr)) {
360 // We have a zero-valued integer expression, which can also be
361 // treated as a NULL pointer that can be converted to any other
362 // pointer type.
363 Expr *npe = new NullPointerExpr(pos);
364 if (lDoTypeConv(PointerType::Void, toType, &npe, failureOk, errorMsgBase, pos)) {
365 *expr = npe;
366 return true;
367 }
368 return false;
369 }
370
371 // Need to check this early, since otherwise the [sic] "unbound"
372 // variability of SOA struct types causes things to get messy if that
373 // hasn't been detected...
374 if (toStructType && fromStructType && (toStructType->GetSOAWidth() != fromStructType->GetSOAWidth())) {
375 if (!failureOk)
376 Error(pos,
377 "Can't convert between incompatible struct types \"%s\" "
378 "and \"%s\" for %s.",
379 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
380 return false;
381 }
382
383 // Convert from type T -> const T; just return a TypeCast expr, which
384 // can handle this
385 if (Type::EqualIgnoringConst(toType, fromType) && toType->IsConstType() == true && fromType->IsConstType() == false)
386 goto typecast_ok;
387
388 if (CastType<ReferenceType>(fromType)) {
389 if (CastType<ReferenceType>(toType)) {
390 // Convert from a reference to a type to a const reference to a type;
391 // this is handled by TypeCastExpr
392 if (Type::Equal(toType->GetReferenceTarget(), fromType->GetReferenceTarget()->GetAsConstType()))
393 goto typecast_ok;
394
395 const ArrayType *atFrom = CastType<ArrayType>(fromType->GetReferenceTarget());
396 const ArrayType *atTo = CastType<ArrayType>(toType->GetReferenceTarget());
397
398 if (atFrom != NULL && atTo != NULL && Type::Equal(atFrom->GetElementType(), atTo->GetElementType())) {
399 goto typecast_ok;
400 } else {
401 if (!failureOk)
402 Error(pos,
403 "Can't convert between incompatible reference types \"%s\" "
404 "and \"%s\" for %s.",
405 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
406 return false;
407 }
408 } else {
409 // convert from a reference T -> T
410 if (expr != NULL) {
411 Expr *drExpr = new RefDerefExpr(*expr, pos);
412 if (lDoTypeConv(drExpr->GetType(), toType, &drExpr, failureOk, errorMsgBase, pos) == true) {
413 *expr = drExpr;
414 return true;
415 }
416 return false;
417 } else
418 return lDoTypeConv(fromType->GetReferenceTarget(), toType, NULL, failureOk, errorMsgBase, pos);
419 }
420 } else if (CastType<ReferenceType>(toType)) {
421 // T -> reference T
422 if (expr != NULL) {
423 Expr *rExpr = new ReferenceExpr(*expr, pos);
424 if (lDoTypeConv(rExpr->GetType(), toType, &rExpr, failureOk, errorMsgBase, pos) == true) {
425 *expr = rExpr;
426 return true;
427 }
428 return false;
429 } else {
430 ReferenceType rt(fromType);
431 return lDoTypeConv(&rt, toType, NULL, failureOk, errorMsgBase, pos);
432 }
433 } else if (Type::Equal(toType, fromType->GetAsNonConstType()))
434 // convert: const T -> T (as long as T isn't a reference)
435 goto typecast_ok;
436
437 fromType = fromType->GetReferenceTarget();
438 toType = toType->GetReferenceTarget();
439 if (toArrayType && fromArrayType) {
440 if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType())) {
441 // the case of different element counts should have returned
442 // successfully earlier, yes??
443 AssertPos(pos, toArrayType->GetElementCount() != fromArrayType->GetElementCount());
444 goto typecast_ok;
445 } else if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType()->GetAsConstType())) {
446 // T[x] -> const T[x]
447 goto typecast_ok;
448 } else {
449 if (!failureOk)
450 Error(pos, "Array type \"%s\" can't be converted to type \"%s\" for %s.", fromType->GetString().c_str(),
451 toType->GetString().c_str(), errorMsgBase);
452 return false;
453 }
454 }
455
456 if (toVectorType && fromVectorType) {
457 // converting e.g. int<n> -> float<n>
458 if (fromVectorType->GetElementCount() != toVectorType->GetElementCount()) {
459 if (!failureOk)
460 Error(pos,
461 "Can't convert between differently sized vector types "
462 "\"%s\" -> \"%s\" for %s.",
463 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
464 return false;
465 }
466 goto typecast_ok;
467 }
468
469 if (toStructType && fromStructType) {
470 if (!Type::Equal(toStructType->GetAsUniformType()->GetAsConstType(),
471 fromStructType->GetAsUniformType()->GetAsConstType())) {
472 if (!failureOk)
473 Error(pos,
474 "Can't convert between different struct types "
475 "\"%s\" and \"%s\" for %s.",
476 fromStructType->GetString().c_str(), toStructType->GetString().c_str(), errorMsgBase);
477 return false;
478 }
479 goto typecast_ok;
480 }
481
482 if (toEnumType != NULL && fromEnumType != NULL) {
483 // No implicit conversions between different enum types
484 if (!Type::EqualIgnoringConst(toEnumType->GetAsUniformType(), fromEnumType->GetAsUniformType())) {
485 if (!failureOk)
486 Error(pos,
487 "Can't convert between different enum types "
488 "\"%s\" and \"%s\" for %s",
489 fromEnumType->GetString().c_str(), toEnumType->GetString().c_str(), errorMsgBase);
490 return false;
491 }
492 goto typecast_ok;
493 }
494
495 // enum -> atomic (integer, generally...) is always ok
496 if (fromEnumType != NULL) {
497 // Cannot convert to anything other than atomic
498 if (toAtomicType == NULL && toVectorType == NULL) {
499 if (!failureOk)
500 Error(pos,
501 "Type conversion from \"%s\" to \"%s\" for %s is not "
502 "possible.",
503 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
504 return false;
505 }
506 goto typecast_ok;
507 }
508
509 // from here on out, the from type can only be atomic something or
510 // other...
511 if (fromAtomicType == NULL) {
512 if (!failureOk)
513 Error(pos,
514 "Type conversion from \"%s\" to \"%s\" for %s is not "
515 "possible.",
516 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
517 return false;
518 }
519
520 // scalar -> short-vector conversions
521 if (toVectorType != NULL && (fromType->GetSOAWidth() == toType->GetSOAWidth()))
522 goto typecast_ok;
523
524 // ok, it better be a scalar->scalar conversion of some sort by now
525 if (toAtomicType == NULL) {
526 if (!failureOk)
527 Error(pos,
528 "Type conversion from \"%s\" to \"%s\" for %s is "
529 "not possible",
530 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
531 return false;
532 }
533
534 if (fromType->GetSOAWidth() != toType->GetSOAWidth()) {
535 if (!failureOk)
536 Error(pos,
537 "Can't convert between types \"%s\" and \"%s\" with "
538 "different SOA widths for %s.",
539 fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
540 return false;
541 }
542
543 typecast_ok:
544 if (expr != NULL)
545 *expr = new TypeCastExpr(toType, *expr, pos);
546 return true;
547 }
548
CanConvertTypes(const Type * fromType,const Type * toType,const char * errorMsgBase,SourcePos pos)549 bool ispc::CanConvertTypes(const Type *fromType, const Type *toType, const char *errorMsgBase, SourcePos pos) {
550 return lDoTypeConv(fromType, toType, NULL, errorMsgBase == NULL, errorMsgBase, pos);
551 }
552
TypeConvertExpr(Expr * expr,const Type * toType,const char * errorMsgBase)553 Expr *ispc::TypeConvertExpr(Expr *expr, const Type *toType, const char *errorMsgBase) {
554 if (expr == NULL)
555 return NULL;
556
557 #if 0
558 Debug(expr->pos, "type convert %s -> %s.", expr->GetType()->GetString().c_str(),
559 toType->GetString().c_str());
560 #endif
561
562 const Type *fromType = expr->GetType();
563 Expr *e = expr;
564 if (lDoTypeConv(fromType, toType, &e, false, errorMsgBase, expr->pos))
565 return e;
566 else
567 return NULL;
568 }
569
PossiblyResolveFunctionOverloads(Expr * expr,const Type * type)570 bool ispc::PossiblyResolveFunctionOverloads(Expr *expr, const Type *type) {
571 FunctionSymbolExpr *fse = NULL;
572 const FunctionType *funcType = NULL;
573 if (CastType<PointerType>(type) != NULL && (funcType = CastType<FunctionType>(type->GetBaseType())) &&
574 (fse = llvm::dyn_cast<FunctionSymbolExpr>(expr)) != NULL) {
575 // We're initializing a function pointer with a function symbol,
576 // which in turn may represent an overloaded function. So we need
577 // to try to resolve the overload based on the type of the symbol
578 // we're initializing here.
579 std::vector<const Type *> paramTypes;
580 for (int i = 0; i < funcType->GetNumParameters(); ++i)
581 paramTypes.push_back(funcType->GetParameterType(i));
582
583 if (fse->ResolveOverloads(expr->pos, paramTypes) == false)
584 return false;
585 }
586 return true;
587 }
588
589 /** Utility routine that emits code to initialize a symbol given an
590 initializer expression.
591
592 @param ptr Memory location of storage for the symbol's data
593 @param symName Name of symbol (used in error messages)
594 @param symType Type of variable being initialized
595 @param initExpr Expression for the initializer
596 @param ctx FunctionEmitContext to use for generating instructions
597 @param pos Source file position of the variable being initialized
598 */
InitSymbol(llvm::Value * ptr,const Type * symType,Expr * initExpr,FunctionEmitContext * ctx,SourcePos pos)599 void ispc::InitSymbol(llvm::Value *ptr, const Type *symType, Expr *initExpr, FunctionEmitContext *ctx, SourcePos pos) {
600 if (initExpr == NULL)
601 // leave it uninitialized
602 return;
603
604 // See if we have a constant initializer a this point
605 std::pair<llvm::Constant *, bool> constValPair = initExpr->GetStorageConstant(symType);
606 llvm::Constant *constValue = constValPair.first;
607 if (constValue != NULL) {
608 // It'd be nice if we could just do a StoreInst(constValue, ptr)
609 // at this point, but unfortunately that doesn't generate great
610 // code (e.g. a bunch of scalar moves for a constant array.) So
611 // instead we'll make a constant static global that holds the
612 // constant value and emit a memcpy to put its value into the
613 // pointer we have.
614 llvm::Type *llvmType = symType->LLVMStorageType(g->ctx);
615 if (llvmType == NULL) {
616 AssertPos(pos, m->errorCount > 0);
617 return;
618 }
619
620 if (Type::IsBasicType(symType))
621 ctx->StoreInst(constValue, ptr, symType, symType->IsUniformType());
622 else {
623 llvm::Value *constPtr =
624 new llvm::GlobalVariable(*m->module, llvmType, true /* const */, llvm::GlobalValue::InternalLinkage,
625 constValue, "const_initializer");
626 llvm::Value *size = g->target->SizeOf(llvmType, ctx->GetCurrentBasicBlock());
627 ctx->MemcpyInst(ptr, constPtr, size);
628 }
629
630 return;
631 }
632
633 // If the initializer is a straight up expression that isn't an
634 // ExprList, then we'll see if we can type convert it to the type of
635 // the variable.
636 if (llvm::dyn_cast<ExprList>(initExpr) == NULL) {
637 if (PossiblyResolveFunctionOverloads(initExpr, symType) == false)
638 return;
639 initExpr = TypeConvertExpr(initExpr, symType, "initializer");
640
641 if (initExpr == NULL)
642 return;
643
644 llvm::Value *initializerValue = initExpr->GetValue(ctx);
645 if (initializerValue != NULL)
646 // Bingo; store the value in the variable's storage
647 ctx->StoreInst(initializerValue, ptr, symType, symType->IsUniformType());
648 return;
649 }
650
651 // Atomic types and enums can be initialized with { ... } initializer
652 // expressions if they have a single element (except for SOA types,
653 // which are handled below).
654 if (symType->IsSOAType() == false && Type::IsBasicType(symType)) {
655 ExprList *elist = llvm::dyn_cast<ExprList>(initExpr);
656 if (elist != NULL) {
657 if (elist->exprs.size() == 1) {
658 InitSymbol(ptr, symType, elist->exprs[0], ctx, pos);
659 return;
660 } else if (symType->IsVaryingType() == false) {
661 Error(initExpr->pos,
662 "Expression list initializers with "
663 "multiple values can't be used with type \"%s\".",
664 symType->GetString().c_str());
665 return;
666 }
667 } else
668 return;
669 }
670
671 const ReferenceType *rt = CastType<ReferenceType>(symType);
672 if (rt) {
673 if (!Type::Equal(initExpr->GetType(), rt)) {
674 Error(initExpr->pos,
675 "Initializer for reference type \"%s\" must have same "
676 "reference type itself. \"%s\" is incompatible.",
677 rt->GetString().c_str(), initExpr->GetType()->GetString().c_str());
678 return;
679 }
680
681 llvm::Value *initializerValue = initExpr->GetValue(ctx);
682 if (initializerValue)
683 ctx->StoreInst(initializerValue, ptr, initExpr->GetType(), symType->IsUniformType());
684 return;
685 }
686
687 // Handle initiailizers for SOA types as well as for structs, arrays,
688 // and vectors.
689 const CollectionType *collectionType = CastType<CollectionType>(symType);
690 if (collectionType != NULL || symType->IsSOAType() ||
691 (Type::IsBasicType(symType) && symType->IsVaryingType() == true)) {
692 // Make default value equivalent to number of elements for varying
693 int nElements = g->target->getVectorWidth();
694 if (collectionType)
695 nElements = collectionType->GetElementCount();
696 else if (symType->IsSOAType())
697 nElements = symType->GetSOAWidth();
698
699 std::string name;
700 if (CastType<StructType>(symType) != NULL)
701 name = "struct";
702 else if (CastType<ArrayType>(symType) != NULL)
703 name = "array";
704 else if (CastType<VectorType>(symType) != NULL)
705 name = "vector";
706 else if (symType->IsSOAType() || (Type::IsBasicType(symType) && symType->IsVaryingType() == true))
707 name = symType->GetVariability().GetString();
708 else
709 FATAL("Unexpected CollectionType in InitSymbol()");
710
711 // There are two cases for initializing these types; either a
712 // single initializer may be provided (float foo[3] = 0;), in which
713 // case all of the elements are initialized to the given value, or
714 // an initializer list may be provided (float foo[3] = { 1,2,3 }),
715 // in which case the elements are initialized with the
716 // corresponding values.
717 ExprList *exprList = llvm::dyn_cast<ExprList>(initExpr);
718 if (exprList != NULL) {
719 // The { ... } case; make sure we have the no more expressions
720 // in the ExprList as we have struct members
721 int nInits = exprList->exprs.size();
722 if (nInits > nElements) {
723 Error(initExpr->pos,
724 "Initializer for %s type \"%s\" requires "
725 "no more than %d values; %d provided.",
726 name.c_str(), symType->GetString().c_str(), nElements, nInits);
727 return;
728 } else if ((Type::IsBasicType(symType) && symType->IsVaryingType() == true) && (nInits < nElements)) {
729 Error(initExpr->pos,
730 "Initializer for %s type \"%s\" requires "
731 "%d values; %d provided.",
732 name.c_str(), symType->GetString().c_str(), nElements, nInits);
733 return;
734 }
735
736 // Initialize each element with the corresponding value from
737 // the ExprList
738 for (int i = 0; i < nElements; ++i) {
739 // For SOA types and varying, the element type is the uniform variant
740 // of the underlying type
741 const Type *elementType =
742 collectionType ? collectionType->GetElementType(i) : symType->GetAsUniformType();
743
744 llvm::Value *ep;
745 if (CastType<StructType>(symType) != NULL)
746 ep = ctx->AddElementOffset(ptr, i, NULL, "element");
747 else
748 ep = ctx->GetElementPtrInst(ptr, LLVMInt32(0), LLVMInt32(i), PointerType::GetUniform(elementType),
749 "gep");
750
751 if (i < nInits)
752 InitSymbol(ep, elementType, exprList->exprs[i], ctx, pos);
753 else {
754 // If we don't have enough initializer values, initialize the
755 // rest as zero.
756 llvm::Type *llvmType = elementType->LLVMStorageType(g->ctx);
757 if (llvmType == NULL) {
758 AssertPos(pos, m->errorCount > 0);
759 return;
760 }
761
762 llvm::Constant *zeroInit = llvm::Constant::getNullValue(llvmType);
763 ctx->StoreInst(zeroInit, ep, elementType, elementType->IsUniformType());
764 }
765 }
766 } else if (collectionType) {
767 Error(initExpr->pos, "Can't assign type \"%s\" to \"%s\".", initExpr->GetType()->GetString().c_str(),
768 collectionType->GetString().c_str());
769 } else {
770 FATAL("CollectionType is NULL in InitSymbol()");
771 }
772 return;
773 }
774
775 FATAL("Unexpected Type in InitSymbol()");
776 }
777
778 ///////////////////////////////////////////////////////////////////////////
779
780 /** Given an atomic or vector type, this returns a boolean type with the
781 same "shape". In other words, if the given type is a vector type of
782 three uniform ints, the returned type is a vector type of three uniform
783 bools. */
lMatchingBoolType(const Type * type)784 static const Type *lMatchingBoolType(const Type *type) {
785 bool uniformTest = type->IsUniformType();
786 const AtomicType *boolBase = uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool;
787 const VectorType *vt = CastType<VectorType>(type);
788 if (vt != NULL)
789 return new VectorType(boolBase, vt->GetElementCount());
790 else {
791 Assert(Type::IsBasicType(type) || type->IsReferenceType());
792 return boolBase;
793 }
794 }
795
796 ///////////////////////////////////////////////////////////////////////////
797 // UnaryExpr
798
lLLVMConstantValue(const Type * type,llvm::LLVMContext * ctx,double value)799 static llvm::Constant *lLLVMConstantValue(const Type *type, llvm::LLVMContext *ctx, double value) {
800 const AtomicType *atomicType = CastType<AtomicType>(type);
801 const EnumType *enumType = CastType<EnumType>(type);
802 const VectorType *vectorType = CastType<VectorType>(type);
803 const PointerType *pointerType = CastType<PointerType>(type);
804
805 // This function is only called with, and only works for atomic, enum,
806 // and vector types.
807 Assert(atomicType != NULL || enumType != NULL || vectorType != NULL || pointerType != NULL);
808
809 if (atomicType != NULL || enumType != NULL) {
810 // If it's an atomic or enuemrator type, then figure out which of
811 // the llvmutil.h functions to call to get the corresponding
812 // constant and then call it...
813 bool isUniform = type->IsUniformType();
814 AtomicType::BasicType basicType = (enumType != NULL) ? AtomicType::TYPE_UINT32 : atomicType->basicType;
815
816 switch (basicType) {
817 case AtomicType::TYPE_VOID:
818 FATAL("can't get constant value for void type");
819 return NULL;
820 case AtomicType::TYPE_BOOL:
821 if (isUniform)
822 return (value != 0.) ? LLVMTrue : LLVMFalse;
823 else
824 return LLVMBoolVector(value != 0.);
825 case AtomicType::TYPE_INT8: {
826 int i = (int)value;
827 Assert((double)i == value);
828 return isUniform ? LLVMInt8(i) : LLVMInt8Vector(i);
829 }
830 case AtomicType::TYPE_UINT8: {
831 unsigned int i = (unsigned int)value;
832 return isUniform ? LLVMUInt8(i) : LLVMUInt8Vector(i);
833 }
834 case AtomicType::TYPE_INT16: {
835 int i = (int)value;
836 Assert((double)i == value);
837 return isUniform ? LLVMInt16(i) : LLVMInt16Vector(i);
838 }
839 case AtomicType::TYPE_UINT16: {
840 unsigned int i = (unsigned int)value;
841 return isUniform ? LLVMUInt16(i) : LLVMUInt16Vector(i);
842 }
843 case AtomicType::TYPE_INT32: {
844 int i = (int)value;
845 Assert((double)i == value);
846 return isUniform ? LLVMInt32(i) : LLVMInt32Vector(i);
847 }
848 case AtomicType::TYPE_UINT32: {
849 unsigned int i = (unsigned int)value;
850 return isUniform ? LLVMUInt32(i) : LLVMUInt32Vector(i);
851 }
852 case AtomicType::TYPE_FLOAT:
853 return isUniform ? LLVMFloat((float)value) : LLVMFloatVector((float)value);
854 case AtomicType::TYPE_UINT64: {
855 uint64_t i = (uint64_t)value;
856 Assert(value == (int64_t)i);
857 return isUniform ? LLVMUInt64(i) : LLVMUInt64Vector(i);
858 }
859 case AtomicType::TYPE_INT64: {
860 int64_t i = (int64_t)value;
861 Assert((double)i == value);
862 return isUniform ? LLVMInt64(i) : LLVMInt64Vector(i);
863 }
864 case AtomicType::TYPE_DOUBLE:
865 return isUniform ? LLVMDouble(value) : LLVMDoubleVector(value);
866 default:
867 FATAL("logic error in lLLVMConstantValue");
868 return NULL;
869 }
870 } else if (pointerType != NULL) {
871 Assert(value == 0);
872 if (pointerType->IsUniformType())
873 return llvm::Constant::getNullValue(LLVMTypes::VoidPointerType);
874 else
875 return llvm::Constant::getNullValue(LLVMTypes::VoidPointerVectorType);
876 } else {
877 // For vector types, first get the LLVM constant for the basetype with
878 // a recursive call to lLLVMConstantValue().
879 const Type *baseType = vectorType->GetBaseType();
880 llvm::Constant *constElement = lLLVMConstantValue(baseType, ctx, value);
881 llvm::Type *llvmVectorType = vectorType->LLVMType(ctx);
882
883 // Now create a constant version of the corresponding LLVM type that we
884 // use to represent the VectorType.
885 // FIXME: this is a little ugly in that the fact that ispc represents
886 // uniform VectorTypes as LLVM VectorTypes and varying VectorTypes as
887 // LLVM ArrayTypes leaks into the code here; it feels like this detail
888 // should be better encapsulated?
889 if (baseType->IsUniformType()) {
890 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
891 llvm::FixedVectorType *lvt = llvm::dyn_cast<llvm::FixedVectorType>(llvmVectorType);
892 #else
893 llvm::VectorType *lvt = llvm::dyn_cast<llvm::VectorType>(llvmVectorType);
894 #endif
895 Assert(lvt != NULL);
896 std::vector<llvm::Constant *> vals;
897 for (unsigned int i = 0; i < lvt->getNumElements(); ++i)
898 vals.push_back(constElement);
899 return llvm::ConstantVector::get(vals);
900 } else {
901 llvm::ArrayType *lat = llvm::dyn_cast<llvm::ArrayType>(llvmVectorType);
902 Assert(lat != NULL);
903 std::vector<llvm::Constant *> vals;
904 for (unsigned int i = 0; i < lat->getNumElements(); ++i)
905 vals.push_back(constElement);
906 return llvm::ConstantArray::get(lat, vals);
907 }
908 }
909 }
910
lMaskForSymbol(Symbol * baseSym,FunctionEmitContext * ctx)911 static llvm::Value *lMaskForSymbol(Symbol *baseSym, FunctionEmitContext *ctx) {
912 if (baseSym == NULL)
913 return ctx->GetFullMask();
914
915 if (CastType<PointerType>(baseSym->type) != NULL || CastType<ReferenceType>(baseSym->type) != NULL)
916 // FIXME: for pointers, we really only want to do this for
917 // dereferencing the pointer, not for things like pointer
918 // arithmetic, when we may be able to use the internal mask,
919 // depending on context...
920 return ctx->GetFullMask();
921
922 llvm::Value *mask = (baseSym->parentFunction == ctx->GetFunction() && baseSym->storageClass != SC_STATIC)
923 ? ctx->GetInternalMask()
924 : ctx->GetFullMask();
925 return mask;
926 }
927
928 /** Store the result of an assignment to the given location.
929 */
lStoreAssignResult(llvm::Value * value,llvm::Value * ptr,const Type * valueType,const Type * ptrType,FunctionEmitContext * ctx,Symbol * baseSym)930 static void lStoreAssignResult(llvm::Value *value, llvm::Value *ptr, const Type *valueType, const Type *ptrType,
931 FunctionEmitContext *ctx, Symbol *baseSym) {
932 Assert(baseSym == NULL || baseSym->varyingCFDepth <= ctx->VaryingCFDepth());
933 if (!g->opt.disableMaskedStoreToStore && !g->opt.disableMaskAllOnOptimizations && baseSym != NULL &&
934 baseSym->varyingCFDepth == ctx->VaryingCFDepth() && baseSym->storageClass != SC_STATIC &&
935 CastType<ReferenceType>(baseSym->type) == NULL && CastType<PointerType>(baseSym->type) == NULL) {
936 // If the variable is declared at the same varying control flow
937 // depth as where it's being assigned, then we don't need to do any
938 // masking but can just do the assignment as if all the lanes were
939 // known to be on. While this may lead to random/garbage values
940 // written into the lanes that are off, by definition they will
941 // never be accessed, since those lanes aren't executing, and won't
942 // be executing at this scope or any other one before the variable
943 // goes out of scope.
944 ctx->StoreInst(value, ptr, LLVMMaskAllOn, valueType, ptrType);
945 } else {
946 ctx->StoreInst(value, ptr, lMaskForSymbol(baseSym, ctx), valueType, ptrType);
947 }
948 }
949
950 /** Utility routine to emit code to do a {pre,post}-{inc,dec}rement of the
951 given expresion.
952 */
lEmitPrePostIncDec(UnaryExpr::Op op,Expr * expr,SourcePos pos,FunctionEmitContext * ctx)953 static llvm::Value *lEmitPrePostIncDec(UnaryExpr::Op op, Expr *expr, SourcePos pos, FunctionEmitContext *ctx) {
954 const Type *type = expr->GetType();
955 if (type == NULL)
956 return NULL;
957
958 // Get both the lvalue and the rvalue of the given expression
959 llvm::Value *lvalue = NULL, *rvalue = NULL;
960 const Type *lvalueType = NULL;
961 if (CastType<ReferenceType>(type) != NULL) {
962 lvalueType = type;
963 type = type->GetReferenceTarget();
964 lvalue = expr->GetValue(ctx);
965
966 RefDerefExpr *deref = new RefDerefExpr(expr, expr->pos);
967 rvalue = deref->GetValue(ctx);
968 } else {
969 lvalue = expr->GetLValue(ctx);
970 lvalueType = expr->GetLValueType();
971 rvalue = expr->GetValue(ctx);
972 }
973
974 if (lvalue == NULL) {
975 // If we can't get a lvalue, then we have an error here
976 const char *prepost = (op == UnaryExpr::PreInc || op == UnaryExpr::PreDec) ? "pre" : "post";
977 const char *incdec = (op == UnaryExpr::PreInc || op == UnaryExpr::PostInc) ? "increment" : "decrement";
978 Error(pos, "Can't %s-%s non-lvalues.", prepost, incdec);
979 return NULL;
980 }
981
982 // Emit code to do the appropriate addition/subtraction to the
983 // expression's old value
984 ctx->SetDebugPos(pos);
985 llvm::Value *binop = NULL;
986 int delta = (op == UnaryExpr::PreInc || op == UnaryExpr::PostInc) ? 1 : -1;
987
988 std::string opName = rvalue->getName().str();
989 if (op == UnaryExpr::PreInc || op == UnaryExpr::PostInc)
990 opName += "_plus1";
991 else
992 opName += "_minus1";
993
994 if (CastType<PointerType>(type) != NULL) {
995 const Type *incType = type->IsUniformType() ? AtomicType::UniformInt32 : AtomicType::VaryingInt32;
996 llvm::Constant *dval = lLLVMConstantValue(incType, g->ctx, delta);
997 binop = ctx->GetElementPtrInst(rvalue, dval, type, opName.c_str());
998 } else {
999 llvm::Constant *dval = lLLVMConstantValue(type, g->ctx, delta);
1000 if (type->IsFloatType())
1001 binop = ctx->BinaryOperator(llvm::Instruction::FAdd, rvalue, dval, opName.c_str());
1002 else
1003 binop = ctx->BinaryOperator(llvm::Instruction::Add, rvalue, dval, opName.c_str());
1004 }
1005
1006 // And store the result out to the lvalue
1007 Symbol *baseSym = expr->GetBaseSymbol();
1008 lStoreAssignResult(binop, lvalue, type, lvalueType, ctx, baseSym);
1009
1010 // And then if it's a pre increment/decrement, return the final
1011 // computed result; otherwise return the previously-grabbed expression
1012 // value.
1013 return (op == UnaryExpr::PreInc || op == UnaryExpr::PreDec) ? binop : rvalue;
1014 }
1015
1016 /** Utility routine to emit code to negate the given expression.
1017 */
lEmitNegate(Expr * arg,SourcePos pos,FunctionEmitContext * ctx)1018 static llvm::Value *lEmitNegate(Expr *arg, SourcePos pos, FunctionEmitContext *ctx) {
1019 const Type *type = arg->GetType();
1020 llvm::Value *argVal = arg->GetValue(ctx);
1021 if (type == NULL || argVal == NULL)
1022 return NULL;
1023
1024 // Negate by subtracting from zero...
1025 ctx->SetDebugPos(pos);
1026 if (type->IsFloatType()) {
1027 llvm::Value *zero = llvm::ConstantFP::getZeroValueForNegation(type->LLVMType(g->ctx));
1028 return ctx->BinaryOperator(llvm::Instruction::FSub, zero, argVal, llvm::Twine(argVal->getName()) + "_negate");
1029 } else {
1030 llvm::Value *zero = lLLVMConstantValue(type, g->ctx, 0.);
1031 AssertPos(pos, type->IsIntType());
1032 return ctx->BinaryOperator(llvm::Instruction::Sub, zero, argVal, llvm::Twine(argVal->getName()) + "_negate");
1033 }
1034 }
1035
UnaryExpr(Op o,Expr * e,SourcePos p)1036 UnaryExpr::UnaryExpr(Op o, Expr *e, SourcePos p) : Expr(p, UnaryExprID), op(o) { expr = e; }
1037
GetValue(FunctionEmitContext * ctx) const1038 llvm::Value *UnaryExpr::GetValue(FunctionEmitContext *ctx) const {
1039 if (expr == NULL)
1040 return NULL;
1041
1042 ctx->SetDebugPos(pos);
1043
1044 switch (op) {
1045 case PreInc:
1046 case PreDec:
1047 case PostInc:
1048 case PostDec:
1049 return lEmitPrePostIncDec(op, expr, pos, ctx);
1050 case Negate:
1051 return lEmitNegate(expr, pos, ctx);
1052 case LogicalNot: {
1053 llvm::Value *argVal = expr->GetValue(ctx);
1054 return ctx->NotOperator(argVal, llvm::Twine(argVal->getName()) + "_logicalnot");
1055 }
1056 case BitNot: {
1057 llvm::Value *argVal = expr->GetValue(ctx);
1058 return ctx->NotOperator(argVal, llvm::Twine(argVal->getName()) + "_bitnot");
1059 }
1060 default:
1061 FATAL("logic error");
1062 return NULL;
1063 }
1064 }
1065
GetType() const1066 const Type *UnaryExpr::GetType() const {
1067 if (expr == NULL)
1068 return NULL;
1069
1070 const Type *type = expr->GetType();
1071 if (type == NULL)
1072 return NULL;
1073
1074 // For all unary expressions besides logical not, the returned type is
1075 // the same as the source type. Logical not always returns a bool
1076 // type, with the same shape as the input type.
1077 switch (op) {
1078 case PreInc:
1079 case PreDec:
1080 case PostInc:
1081 case PostDec:
1082 case Negate:
1083 case BitNot:
1084 return type;
1085 case LogicalNot:
1086 return lMatchingBoolType(type);
1087 default:
1088 FATAL("error");
1089 return NULL;
1090 }
1091 }
1092
lOptimizeBitNot(ConstExpr * constExpr,const Type * type,SourcePos pos)1093 template <typename T> static Expr *lOptimizeBitNot(ConstExpr *constExpr, const Type *type, SourcePos pos) {
1094 T v[ISPC_MAX_NVEC];
1095 int count = constExpr->GetValues(v);
1096 for (int i = 0; i < count; ++i)
1097 v[i] = ~v[i];
1098 return new ConstExpr(type, v, pos);
1099 }
1100
lOptimizeNegate(ConstExpr * constExpr,const Type * type,SourcePos pos)1101 template <typename T> static Expr *lOptimizeNegate(ConstExpr *constExpr, const Type *type, SourcePos pos) {
1102 T v[ISPC_MAX_NVEC];
1103 int count = constExpr->GetValues(v);
1104 for (int i = 0; i < count; ++i)
1105 v[i] = -v[i];
1106 return new ConstExpr(type, v, pos);
1107 }
1108
Optimize()1109 Expr *UnaryExpr::Optimize() {
1110 ConstExpr *constExpr = llvm::dyn_cast<ConstExpr>(expr);
1111 // If the operand isn't a constant, then we can't do any optimization
1112 // here...
1113 if (constExpr == NULL)
1114 return this;
1115
1116 const Type *type = constExpr->GetType();
1117 bool isEnumType = CastType<EnumType>(type) != NULL;
1118
1119 switch (op) {
1120 case PreInc:
1121 case PreDec:
1122 case PostInc:
1123 case PostDec:
1124 // this shouldn't happen--it's illegal to modify a contant value..
1125 // An error will be issued elsewhere...
1126 return this;
1127 case Negate: {
1128 if (Type::EqualIgnoringConst(type, AtomicType::UniformInt64) ||
1129 Type::EqualIgnoringConst(type, AtomicType::VaryingInt64)) {
1130 return lOptimizeNegate<int64_t>(constExpr, type, pos);
1131 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt64) ||
1132 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt64)) {
1133 return lOptimizeNegate<uint64_t>(constExpr, type, pos);
1134 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt32) ||
1135 Type::EqualIgnoringConst(type, AtomicType::VaryingInt32)) {
1136 return lOptimizeNegate<int32_t>(constExpr, type, pos);
1137 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt32) ||
1138 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt32)) {
1139 return lOptimizeNegate<uint32_t>(constExpr, type, pos);
1140 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt16) ||
1141 Type::EqualIgnoringConst(type, AtomicType::VaryingInt16)) {
1142 return lOptimizeNegate<int16_t>(constExpr, type, pos);
1143 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt16) ||
1144 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt32)) {
1145 return lOptimizeNegate<uint16_t>(constExpr, type, pos);
1146 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt8) ||
1147 Type::EqualIgnoringConst(type, AtomicType::VaryingInt8)) {
1148 return lOptimizeNegate<int8_t>(constExpr, type, pos);
1149 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt8) ||
1150 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt8)) {
1151 return lOptimizeNegate<uint8_t>(constExpr, type, pos);
1152 } else {
1153 // For all the other types, it's safe to stuff whatever we have
1154 // into a double, do the negate as a double, and then return a
1155 // ConstExpr with the same type as the original...
1156 double v[ISPC_MAX_NVEC];
1157 int count = constExpr->GetValues(v);
1158 for (int i = 0; i < count; ++i)
1159 v[i] = -v[i];
1160 return new ConstExpr(constExpr, v);
1161 }
1162 }
1163 case BitNot: {
1164 if (Type::EqualIgnoringConst(type, AtomicType::UniformInt8) ||
1165 Type::EqualIgnoringConst(type, AtomicType::VaryingInt8)) {
1166 return lOptimizeBitNot<int8_t>(constExpr, type, pos);
1167 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt8) ||
1168 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt8)) {
1169 return lOptimizeBitNot<uint8_t>(constExpr, type, pos);
1170 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt16) ||
1171 Type::EqualIgnoringConst(type, AtomicType::VaryingInt16)) {
1172 return lOptimizeBitNot<int16_t>(constExpr, type, pos);
1173 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt16) ||
1174 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt16)) {
1175 return lOptimizeBitNot<uint16_t>(constExpr, type, pos);
1176 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt32) ||
1177 Type::EqualIgnoringConst(type, AtomicType::VaryingInt32)) {
1178 return lOptimizeBitNot<int32_t>(constExpr, type, pos);
1179 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt32) ||
1180 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt32) || isEnumType == true) {
1181 return lOptimizeBitNot<uint32_t>(constExpr, type, pos);
1182 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt64) ||
1183 Type::EqualIgnoringConst(type, AtomicType::VaryingInt64)) {
1184 return lOptimizeBitNot<int64_t>(constExpr, type, pos);
1185 } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt64) ||
1186 Type::EqualIgnoringConst(type, AtomicType::VaryingUInt64)) {
1187 return lOptimizeBitNot<uint64_t>(constExpr, type, pos);
1188 } else
1189 FATAL("unexpected type in UnaryExpr::Optimize() / BitNot case");
1190 }
1191 case LogicalNot: {
1192 AssertPos(pos, Type::EqualIgnoringConst(type, AtomicType::UniformBool) ||
1193 Type::EqualIgnoringConst(type, AtomicType::VaryingBool));
1194 bool v[ISPC_MAX_NVEC];
1195 int count = constExpr->GetValues(v);
1196 for (int i = 0; i < count; ++i)
1197 v[i] = !v[i];
1198 return new ConstExpr(type, v, pos);
1199 }
1200 default:
1201 FATAL("unexpected op in UnaryExpr::Optimize()");
1202 return NULL;
1203 }
1204 }
1205
TypeCheck()1206 Expr *UnaryExpr::TypeCheck() {
1207 const Type *type;
1208 if (expr == NULL || (type = expr->GetType()) == NULL)
1209 // something went wrong in type checking...
1210 return NULL;
1211
1212 if (type->IsSOAType()) {
1213 Error(pos, "Can't apply unary operator to SOA type \"%s\".", type->GetString().c_str());
1214 return NULL;
1215 }
1216
1217 if (op == PreInc || op == PreDec || op == PostInc || op == PostDec) {
1218 if (type->IsConstType()) {
1219 Error(pos,
1220 "Can't assign to type \"%s\" on left-hand side of "
1221 "expression.",
1222 type->GetString().c_str());
1223 return NULL;
1224 }
1225
1226 if (type->IsNumericType())
1227 return this;
1228
1229 const PointerType *pt = CastType<PointerType>(type);
1230 if (pt == NULL) {
1231 Error(expr->pos,
1232 "Can only pre/post increment numeric and "
1233 "pointer types, not \"%s\".",
1234 type->GetString().c_str());
1235 return NULL;
1236 }
1237
1238 if (PointerType::IsVoidPointer(type)) {
1239 Error(expr->pos, "Illegal to pre/post increment \"%s\" type.", type->GetString().c_str());
1240 return NULL;
1241 }
1242 if (CastType<UndefinedStructType>(pt->GetBaseType())) {
1243 Error(expr->pos,
1244 "Illegal to pre/post increment pointer to "
1245 "undefined struct type \"%s\".",
1246 type->GetString().c_str());
1247 return NULL;
1248 }
1249
1250 return this;
1251 }
1252
1253 // don't do this for pre/post increment/decrement
1254 if (CastType<ReferenceType>(type)) {
1255 expr = new RefDerefExpr(expr, pos);
1256 type = expr->GetType();
1257 }
1258
1259 if (op == Negate) {
1260 if (!type->IsNumericType()) {
1261 Error(expr->pos, "Negate not allowed for non-numeric type \"%s\".", type->GetString().c_str());
1262 return NULL;
1263 }
1264 } else if (op == LogicalNot) {
1265 const Type *boolType = lMatchingBoolType(type);
1266 expr = TypeConvertExpr(expr, boolType, "logical not");
1267 if (expr == NULL)
1268 return NULL;
1269 } else if (op == BitNot) {
1270 if (!type->IsIntType()) {
1271 Error(expr->pos,
1272 "~ operator can only be used with integer types, "
1273 "not \"%s\".",
1274 type->GetString().c_str());
1275 return NULL;
1276 }
1277 }
1278 return this;
1279 }
1280
EstimateCost() const1281 int UnaryExpr::EstimateCost() const {
1282 if (llvm::dyn_cast<ConstExpr>(expr) != NULL)
1283 return 0;
1284
1285 return COST_SIMPLE_ARITH_LOGIC_OP;
1286 }
1287
Print() const1288 void UnaryExpr::Print() const {
1289 if (!expr || !GetType())
1290 return;
1291
1292 printf("[ %s ] (", GetType()->GetString().c_str());
1293 if (op == PreInc)
1294 printf("++");
1295 if (op == PreDec)
1296 printf("--");
1297 if (op == Negate)
1298 printf("-");
1299 if (op == LogicalNot)
1300 printf("!");
1301 if (op == BitNot)
1302 printf("~");
1303 printf("(");
1304 expr->Print();
1305 printf(")");
1306 if (op == PostInc)
1307 printf("++");
1308 if (op == PostDec)
1309 printf("--");
1310 printf(")");
1311 pos.Print();
1312 }
1313
1314 ///////////////////////////////////////////////////////////////////////////
1315 // BinaryExpr
1316
lOpString(BinaryExpr::Op op)1317 static const char *lOpString(BinaryExpr::Op op) {
1318 switch (op) {
1319 case BinaryExpr::Add:
1320 return "+";
1321 case BinaryExpr::Sub:
1322 return "-";
1323 case BinaryExpr::Mul:
1324 return "*";
1325 case BinaryExpr::Div:
1326 return "/";
1327 case BinaryExpr::Mod:
1328 return "%";
1329 case BinaryExpr::Shl:
1330 return "<<";
1331 case BinaryExpr::Shr:
1332 return ">>";
1333 case BinaryExpr::Lt:
1334 return "<";
1335 case BinaryExpr::Gt:
1336 return ">";
1337 case BinaryExpr::Le:
1338 return "<=";
1339 case BinaryExpr::Ge:
1340 return ">=";
1341 case BinaryExpr::Equal:
1342 return "==";
1343 case BinaryExpr::NotEqual:
1344 return "!=";
1345 case BinaryExpr::BitAnd:
1346 return "&";
1347 case BinaryExpr::BitXor:
1348 return "^";
1349 case BinaryExpr::BitOr:
1350 return "|";
1351 case BinaryExpr::LogicalAnd:
1352 return "&&";
1353 case BinaryExpr::LogicalOr:
1354 return "||";
1355 case BinaryExpr::Comma:
1356 return ",";
1357 default:
1358 FATAL("unimplemented case in lOpString()");
1359 return "";
1360 }
1361 }
1362
1363 /** Utility routine to emit the binary bitwise operator corresponding to
1364 the given BinaryExpr::Op.
1365 */
lEmitBinaryBitOp(BinaryExpr::Op op,llvm::Value * arg0Val,llvm::Value * arg1Val,bool isUnsigned,FunctionEmitContext * ctx)1366 static llvm::Value *lEmitBinaryBitOp(BinaryExpr::Op op, llvm::Value *arg0Val, llvm::Value *arg1Val, bool isUnsigned,
1367 FunctionEmitContext *ctx) {
1368 llvm::Instruction::BinaryOps inst;
1369 switch (op) {
1370 case BinaryExpr::Shl:
1371 inst = llvm::Instruction::Shl;
1372 break;
1373 case BinaryExpr::Shr:
1374 if (isUnsigned)
1375 inst = llvm::Instruction::LShr;
1376 else
1377 inst = llvm::Instruction::AShr;
1378 break;
1379 case BinaryExpr::BitAnd:
1380 inst = llvm::Instruction::And;
1381 break;
1382 case BinaryExpr::BitXor:
1383 inst = llvm::Instruction::Xor;
1384 break;
1385 case BinaryExpr::BitOr:
1386 inst = llvm::Instruction::Or;
1387 break;
1388 default:
1389 FATAL("logic error in lEmitBinaryBitOp()");
1390 return NULL;
1391 }
1392
1393 return ctx->BinaryOperator(inst, arg0Val, arg1Val, "bitop");
1394 }
1395
lEmitBinaryPointerArith(BinaryExpr::Op op,llvm::Value * value0,llvm::Value * value1,const Type * type0,const Type * type1,FunctionEmitContext * ctx,SourcePos pos)1396 static llvm::Value *lEmitBinaryPointerArith(BinaryExpr::Op op, llvm::Value *value0, llvm::Value *value1,
1397 const Type *type0, const Type *type1, FunctionEmitContext *ctx,
1398 SourcePos pos) {
1399 const PointerType *ptrType = CastType<PointerType>(type0);
1400 AssertPos(pos, ptrType != NULL);
1401 switch (op) {
1402 case BinaryExpr::Add:
1403 // ptr + integer
1404 return ctx->GetElementPtrInst(value0, value1, ptrType, "ptrmath");
1405 break;
1406 case BinaryExpr::Sub: {
1407 if (CastType<PointerType>(type1) != NULL) {
1408 AssertPos(pos, Type::EqualIgnoringConst(type0, type1));
1409
1410 if (ptrType->IsSlice()) {
1411 llvm::Value *p0 = ctx->ExtractInst(value0, 0);
1412 llvm::Value *p1 = ctx->ExtractInst(value1, 0);
1413 const Type *majorType = ptrType->GetAsNonSlice();
1414 llvm::Value *majorDelta = lEmitBinaryPointerArith(op, p0, p1, majorType, majorType, ctx, pos);
1415
1416 int soaWidth = ptrType->GetBaseType()->GetSOAWidth();
1417 AssertPos(pos, soaWidth > 0);
1418 llvm::Value *soaScale = LLVMIntAsType(soaWidth, majorDelta->getType());
1419
1420 llvm::Value *majorScale =
1421 ctx->BinaryOperator(llvm::Instruction::Mul, majorDelta, soaScale, "major_soa_scaled");
1422
1423 llvm::Value *m0 = ctx->ExtractInst(value0, 1);
1424 llvm::Value *m1 = ctx->ExtractInst(value1, 1);
1425 llvm::Value *minorDelta = ctx->BinaryOperator(llvm::Instruction::Sub, m0, m1, "minor_soa_delta");
1426
1427 ctx->MatchIntegerTypes(&majorScale, &minorDelta);
1428 return ctx->BinaryOperator(llvm::Instruction::Add, majorScale, minorDelta, "soa_ptrdiff");
1429 }
1430
1431 // ptr - ptr
1432 if (ptrType->IsUniformType()) {
1433 value0 = ctx->PtrToIntInst(value0);
1434 value1 = ctx->PtrToIntInst(value1);
1435 }
1436
1437 // Compute the difference in bytes
1438 llvm::Value *delta = ctx->BinaryOperator(llvm::Instruction::Sub, value0, value1, "ptr_diff");
1439
1440 // Now divide by the size of the type that the pointer
1441 // points to in order to return the difference in elements.
1442 llvm::Type *llvmElementType = ptrType->GetBaseType()->LLVMType(g->ctx);
1443 llvm::Value *size = g->target->SizeOf(llvmElementType, ctx->GetCurrentBasicBlock());
1444 if (ptrType->IsVaryingType())
1445 size = ctx->SmearUniform(size);
1446
1447 if (g->target->is32Bit() == false && g->opt.force32BitAddressing == true) {
1448 // If we're doing 32-bit addressing math on a 64-bit
1449 // target, then trunc the delta down to a 32-bit value.
1450 // (Thus also matching what will be a 32-bit value
1451 // returned from SizeOf above.)
1452 if (ptrType->IsUniformType())
1453 delta = ctx->TruncInst(delta, LLVMTypes::Int32Type, "trunc_ptr_delta");
1454 else
1455 delta = ctx->TruncInst(delta, LLVMTypes::Int32VectorType, "trunc_ptr_delta");
1456 }
1457
1458 // And now do the actual division
1459 return ctx->BinaryOperator(llvm::Instruction::SDiv, delta, size, "element_diff");
1460 } else {
1461 // ptr - integer
1462 llvm::Value *zero = lLLVMConstantValue(type1, g->ctx, 0.);
1463 llvm::Value *negOffset = ctx->BinaryOperator(llvm::Instruction::Sub, zero, value1, "negate");
1464 // Do a GEP as ptr + -integer
1465 return ctx->GetElementPtrInst(value0, negOffset, ptrType, "ptrmath");
1466 }
1467 }
1468 default:
1469 FATAL("Logic error in lEmitBinaryArith() for pointer type case");
1470 return NULL;
1471 }
1472 }
1473
1474 /** Utility routine to emit binary arithmetic operator based on the given
1475 BinaryExpr::Op.
1476 */
lEmitBinaryArith(BinaryExpr::Op op,llvm::Value * value0,llvm::Value * value1,const Type * type0,const Type * type1,FunctionEmitContext * ctx,SourcePos pos)1477 static llvm::Value *lEmitBinaryArith(BinaryExpr::Op op, llvm::Value *value0, llvm::Value *value1, const Type *type0,
1478 const Type *type1, FunctionEmitContext *ctx, SourcePos pos) {
1479 const PointerType *ptrType = CastType<PointerType>(type0);
1480
1481 if (ptrType != NULL)
1482 return lEmitBinaryPointerArith(op, value0, value1, type0, type1, ctx, pos);
1483 else {
1484 AssertPos(pos, Type::EqualIgnoringConst(type0, type1));
1485
1486 llvm::Instruction::BinaryOps inst;
1487 bool isFloatOp = type0->IsFloatType();
1488 bool isUnsignedOp = type0->IsUnsignedType();
1489
1490 const char *opName = NULL;
1491 switch (op) {
1492 case BinaryExpr::Add:
1493 opName = "add";
1494 inst = isFloatOp ? llvm::Instruction::FAdd : llvm::Instruction::Add;
1495 break;
1496 case BinaryExpr::Sub:
1497 opName = "sub";
1498 inst = isFloatOp ? llvm::Instruction::FSub : llvm::Instruction::Sub;
1499 break;
1500 case BinaryExpr::Mul:
1501 opName = "mul";
1502 inst = isFloatOp ? llvm::Instruction::FMul : llvm::Instruction::Mul;
1503 break;
1504 case BinaryExpr::Div:
1505 opName = "div";
1506 if (type0->IsVaryingType() && !isFloatOp)
1507 PerformanceWarning(pos, "Division with varying integer types is "
1508 "very inefficient.");
1509 inst = isFloatOp ? llvm::Instruction::FDiv
1510 : (isUnsignedOp ? llvm::Instruction::UDiv : llvm::Instruction::SDiv);
1511 break;
1512 case BinaryExpr::Mod:
1513 opName = "mod";
1514 if (type0->IsVaryingType() && !isFloatOp)
1515 PerformanceWarning(pos, "Modulus operator with varying types is "
1516 "very inefficient.");
1517 inst = isFloatOp ? llvm::Instruction::FRem
1518 : (isUnsignedOp ? llvm::Instruction::URem : llvm::Instruction::SRem);
1519 break;
1520 default:
1521 FATAL("Invalid op type passed to lEmitBinaryArith()");
1522 return NULL;
1523 }
1524
1525 return ctx->BinaryOperator(inst, value0, value1,
1526 (((llvm::Twine(opName) + "_") + value0->getName()) + "_") + value1->getName());
1527 }
1528 }
1529
1530 /** Utility routine to emit a binary comparison operator based on the given
1531 BinaryExpr::Op.
1532 */
lEmitBinaryCmp(BinaryExpr::Op op,llvm::Value * e0Val,llvm::Value * e1Val,const Type * type,FunctionEmitContext * ctx,SourcePos pos)1533 static llvm::Value *lEmitBinaryCmp(BinaryExpr::Op op, llvm::Value *e0Val, llvm::Value *e1Val, const Type *type,
1534 FunctionEmitContext *ctx, SourcePos pos) {
1535 bool isFloatOp = type->IsFloatType();
1536 bool isUnsignedOp = type->IsUnsignedType();
1537
1538 llvm::CmpInst::Predicate pred;
1539 const char *opName = NULL;
1540 switch (op) {
1541 case BinaryExpr::Lt:
1542 opName = "less";
1543 pred = isFloatOp ? llvm::CmpInst::FCMP_OLT : (isUnsignedOp ? llvm::CmpInst::ICMP_ULT : llvm::CmpInst::ICMP_SLT);
1544 break;
1545 case BinaryExpr::Gt:
1546 opName = "greater";
1547 pred = isFloatOp ? llvm::CmpInst::FCMP_OGT : (isUnsignedOp ? llvm::CmpInst::ICMP_UGT : llvm::CmpInst::ICMP_SGT);
1548 break;
1549 case BinaryExpr::Le:
1550 opName = "lessequal";
1551 pred = isFloatOp ? llvm::CmpInst::FCMP_OLE : (isUnsignedOp ? llvm::CmpInst::ICMP_ULE : llvm::CmpInst::ICMP_SLE);
1552 break;
1553 case BinaryExpr::Ge:
1554 opName = "greaterequal";
1555 pred = isFloatOp ? llvm::CmpInst::FCMP_OGE : (isUnsignedOp ? llvm::CmpInst::ICMP_UGE : llvm::CmpInst::ICMP_SGE);
1556 break;
1557 case BinaryExpr::Equal:
1558 opName = "equal";
1559 pred = isFloatOp ? llvm::CmpInst::FCMP_OEQ : llvm::CmpInst::ICMP_EQ;
1560 break;
1561 case BinaryExpr::NotEqual:
1562 opName = "notequal";
1563 pred = isFloatOp ? llvm::CmpInst::FCMP_UNE : llvm::CmpInst::ICMP_NE;
1564 break;
1565 default:
1566 FATAL("error in lEmitBinaryCmp()");
1567 return NULL;
1568 }
1569
1570 llvm::Value *cmp = ctx->CmpInst(isFloatOp ? llvm::Instruction::FCmp : llvm::Instruction::ICmp, pred, e0Val, e1Val,
1571 (((llvm::Twine(opName) + "_") + e0Val->getName()) + "_") + e1Val->getName());
1572 // This is a little ugly: CmpInst returns i1 values, but we use vectors
1573 // of i32s for varying bool values; type convert the result here if
1574 // needed.
1575 if (type->IsVaryingType())
1576 cmp = ctx->I1VecToBoolVec(cmp);
1577
1578 return cmp;
1579 }
1580
BinaryExpr(Op o,Expr * a,Expr * b,SourcePos p)1581 BinaryExpr::BinaryExpr(Op o, Expr *a, Expr *b, SourcePos p) : Expr(p, BinaryExprID), op(o) {
1582 arg0 = a;
1583 arg1 = b;
1584 }
1585
lCreateBinaryOperatorCall(const BinaryExpr::Op bop,Expr * a0,Expr * a1,Expr * & op,const SourcePos & sp)1586 bool lCreateBinaryOperatorCall(const BinaryExpr::Op bop, Expr *a0, Expr *a1, Expr *&op, const SourcePos &sp) {
1587 bool abort = false;
1588 if ((a0 == NULL) || (a1 == NULL)) {
1589 return abort;
1590 }
1591 Expr *arg0 = a0;
1592 Expr *arg1 = a1;
1593 const Type *type0 = arg0->GetType();
1594 const Type *type1 = arg1->GetType();
1595
1596 // If either operand is a reference, dereference it before we move
1597 // forward
1598 if (CastType<ReferenceType>(type0) != NULL) {
1599 arg0 = new RefDerefExpr(arg0, arg0->pos);
1600 type0 = arg0->GetType();
1601 }
1602 if (CastType<ReferenceType>(type1) != NULL) {
1603 arg1 = new RefDerefExpr(arg1, arg1->pos);
1604 type1 = arg1->GetType();
1605 }
1606 if ((type0 == NULL) || (type1 == NULL)) {
1607 return abort;
1608 }
1609 if (CastType<StructType>(type0) != NULL || CastType<StructType>(type1) != NULL) {
1610 std::string opName = std::string("operator") + lOpString(bop);
1611 std::vector<Symbol *> funs;
1612 m->symbolTable->LookupFunction(opName.c_str(), &funs);
1613 if (funs.size() == 0) {
1614 Error(sp, "operator %s(%s, %s) is not defined.", opName.c_str(), (type0->GetString()).c_str(),
1615 (type1->GetString()).c_str());
1616 abort = true;
1617 return abort;
1618 }
1619 Expr *func = new FunctionSymbolExpr(opName.c_str(), funs, sp);
1620 ExprList *args = new ExprList(sp);
1621 args->exprs.push_back(arg0);
1622 args->exprs.push_back(arg1);
1623 op = new FunctionCallExpr(func, args, sp);
1624 return abort;
1625 }
1626 return abort;
1627 }
1628
MakeBinaryExpr(BinaryExpr::Op o,Expr * a,Expr * b,SourcePos p)1629 Expr *ispc::MakeBinaryExpr(BinaryExpr::Op o, Expr *a, Expr *b, SourcePos p) {
1630 Expr *op = NULL;
1631 bool abort = lCreateBinaryOperatorCall(o, a, b, op, p);
1632 if (op != NULL) {
1633 return op;
1634 }
1635
1636 // lCreateBinaryOperatorCall can return NULL for 2 cases:
1637 // 1. When there is an error.
1638 // 2. We have to create a new BinaryExpr.
1639 if (abort) {
1640 AssertPos(p, m->errorCount > 0);
1641 return NULL;
1642 }
1643
1644 op = new BinaryExpr(o, a, b, p);
1645 return op;
1646 }
1647
/** Emit code for a && or || logical operator.  In particular, the code
    here handles "short-circuit" evaluation, where the second expression
    isn't evaluated if the value of the first one determines the value of
    the result.

    @param op    BinaryExpr::LogicalAnd or BinaryExpr::LogicalOr; any other
                 value trips an assertion.
    @param arg0  Left operand; its value is always emitted.
    @param arg1  Right operand; emitted unconditionally on the
                 non-short-circuit path, otherwise only inside the "eval_1"
                 basic block.
    @param ctx   Emission context used to create instructions and blocks.
    @param pos   Source position used for assertions and diagnostics.
    @return The value of the logical expression, or NULL when an operand
            type, an operand value, or the common result type is
            unavailable (in which case an error is asserted to have been
            reported already).
 */
llvm::Value *lEmitLogicalOp(BinaryExpr::Op op, Expr *arg0, Expr *arg1, FunctionEmitContext *ctx, SourcePos pos) {

    const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
    if (type0 == NULL || type1 == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    // There is overhead (branches, etc.), to short-circuiting, so if the
    // right side of the expression is a) relatively simple, and b) can be
    // safely executed with an all-off execution mask, then we just
    // evaluate both sides and then the logical operator in that case.
    // The cost threshold differs between GenX and all other targets.
    int threshold =
        g->target->isGenXTarget() ? PREDICATE_SAFE_SHORT_CIRC_GENX_STATEMENT_COST : PREDICATE_SAFE_IF_STATEMENT_COST;
    bool shortCircuit = (EstimateCost(arg1) > threshold || SafeToRunWithMaskAllOff(arg1) == false);

    // Skip short-circuiting for VectorTypes as well.
    if ((shortCircuit == false) || CastType<VectorType>(type0) != NULL || CastType<VectorType>(type1) != NULL) {
        // If one of the operands is uniform but the other is varying,
        // promote the uniform one to varying
        if (type0->IsUniformType() && type1->IsVaryingType()) {
            arg0 = TypeConvertExpr(arg0, AtomicType::VaryingBool, lOpString(op));
            AssertPos(pos, arg0 != NULL);
        }
        if (type1->IsUniformType() && type0->IsVaryingType()) {
            arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, lOpString(op));
            AssertPos(pos, arg1 != NULL);
        }

        // Both operands are evaluated here, in order, with no branching.
        llvm::Value *value0 = arg0->GetValue(ctx);
        llvm::Value *value1 = arg1->GetValue(ctx);
        if (value0 == NULL || value1 == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }

        if (op == BinaryExpr::LogicalAnd)
            return ctx->BinaryOperator(llvm::Instruction::And, value0, value1, "logical_and");
        else {
            AssertPos(pos, op == BinaryExpr::LogicalOr);
            return ctx->BinaryOperator(llvm::Instruction::Or, value0, value1, "logical_or");
        }
    }

    // Short-circuit path from here on.
    // Allocate temporary storage for the return value
    const Type *retType = Type::MoreGeneralType(type0, type1, pos, lOpString(op));
    if (retType == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }
    llvm::Value *retPtr = ctx->AllocaInst(retType, "logical_op_mem");
    // Three blocks: one taken when arg1 can be skipped, one that evaluates
    // arg1, and the join block where the stored result is loaded.
    llvm::BasicBlock *bbSkipEvalValue1 = ctx->CreateBasicBlock("skip_eval_1", ctx->GetCurrentBasicBlock());
    llvm::BasicBlock *bbEvalValue1 = ctx->CreateBasicBlock("eval_1", bbSkipEvalValue1);
    llvm::BasicBlock *bbLogicalDone = ctx->CreateBasicBlock("logical_op_done", bbEvalValue1);

    // Evaluate the first operand
    llvm::Value *value0 = arg0->GetValue(ctx);
    if (value0 == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    if (type0->IsUniformType()) {
        // Check to see if the value of the first operand is true or false
        llvm::Value *value0True = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, value0, LLVMTrue);

        if (op == BinaryExpr::LogicalOr) {
            // For ||, if value0 is true, then we skip evaluating value1
            // entirely.
            ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, value0True);

            // If value0 is true, the complete result is true (either
            // uniform or varying)
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            llvm::Value *trueValue = retType->IsUniformType() ? LLVMTrue : LLVMMaskAllOn;
            ctx->StoreInst(trueValue, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        } else {
            AssertPos(pos, op == BinaryExpr::LogicalAnd);

            // Conversely, for &&, if value0 is false, we skip evaluating
            // value1.
            ctx->BranchInst(bbEvalValue1, bbSkipEvalValue1, value0True);

            // In this case, the complete result is false (again, either a
            // uniform or varying false).
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            llvm::Value *falseValue = retType->IsUniformType() ? LLVMFalse : LLVMMaskAllOff;
            ctx->StoreInst(falseValue, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        }

        // Both || and && are in the same situation if the first operand's
        // value didn't resolve the final result: they need to evaluate the
        // value of the second operand, which in turn gives the value for
        // the full expression.
        ctx->SetCurrentBasicBlock(bbEvalValue1);
        // A uniform second operand has to be widened when the overall
        // result is varying so that it matches the storage in retPtr.
        if (type1->IsUniformType() && retType->IsVaryingType()) {
            arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, "logical op");
            AssertPos(pos, arg1 != NULL);
        }

        llvm::Value *value1 = arg1->GetValue(ctx);
        if (value1 == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }
        ctx->StoreInst(value1, retPtr, arg1->GetType(), retType->IsUniformType());
        ctx->BranchInst(bbLogicalDone);

        // In all cases, we end up at the bbLogicalDone basic block;
        // loading the value stored in retPtr in turn gives the overall
        // result.
        ctx->SetCurrentBasicBlock(bbLogicalDone);
        return ctx->LoadInst(retPtr, retType);
    } else {
        // Otherwise, the first operand is varying...  Save the current
        // value of the mask so that we can restore it at the end.
        llvm::Value *oldMask = ctx->GetInternalMask();
        llvm::Value *oldFullMask = ctx->GetFullMask();

        // Convert the second operand to be varying as well, so that we can
        // perform logical vector ops with its value.
        if (type1->IsUniformType()) {
            arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, "logical op");
            AssertPos(pos, arg1 != NULL);
            type1 = arg1->GetType();
        }

        if (op == BinaryExpr::LogicalOr) {
            // See if value0 is true for all currently executing
            // lanes--i.e. if (value0 & mask) == mask.  If so, we don't
            // need to evaluate the second operand of the expression.
            llvm::Value *value0AndMask = ctx->BinaryOperator(llvm::Instruction::And, value0, oldFullMask, "op&mask");
            llvm::Value *equalsMask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, value0AndMask,
                                                   oldFullMask, "value0&mask==mask");
            equalsMask = ctx->I1VecToBoolVec(equalsMask);
            if (!ctx->emitGenXHardwareMask()) {
                // Reduce the per-lane comparison to a single branch condition.
                llvm::Value *allMatch = ctx->All(equalsMask);
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, allMatch);
            } else {
                // If uniform CF is emulated, pass vector value to BranchInst
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, equalsMask);
            }

            // value0 is true for all running lanes, so it can be used for
            // the final result
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            ctx->StoreInst(value0, retPtr, arg0->GetType(), retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);

            // Otherwise, we need to evaluate arg1.  However, first we need
            // to set the execution mask to be (oldMask & ~a); in other
            // words, only execute the instances where value0 is false.
            // For the instances where value0 was true, we need to inhibit
            // execution.
            ctx->SetCurrentBasicBlock(bbEvalValue1);
            ctx->SetInternalMaskAndNot(oldMask, value0);

            llvm::Value *value1 = arg1->GetValue(ctx);
            if (value1 == NULL) {
                // NOTE(review): this error return leaves the internal mask
                // as modified above; it is not restored to oldMask here.
                AssertPos(pos, m->errorCount > 0);
                return NULL;
            }

            // We need to compute the result carefully, since vector
            // elements that were computed when the corresponding lane was
            // disabled have undefined values:
            // result = (value0 & old_mask) | (value1 & current_mask)
            llvm::Value *value1AndMask =
                ctx->BinaryOperator(llvm::Instruction::And, value1, ctx->GetInternalMask(), "op&mask");
            llvm::Value *result = ctx->BinaryOperator(llvm::Instruction::Or, value0AndMask, value1AndMask, "or_result");
            ctx->StoreInst(result, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        } else {
            AssertPos(pos, op == BinaryExpr::LogicalAnd);

            // If value0 is false for all currently running lanes, the
            // overall result must be false: this corresponds to checking
            // if (mask & ~value0) == mask.
            llvm::Value *notValue0 = ctx->NotOperator(value0, "not_value0");
            llvm::Value *notValue0AndMask =
                ctx->BinaryOperator(llvm::Instruction::And, notValue0, oldFullMask, "not_value0&mask");
            llvm::Value *equalsMask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, notValue0AndMask,
                                                   oldFullMask, "not_value0&mask==mask");
            equalsMask = ctx->I1VecToBoolVec(equalsMask);
            if (!ctx->emitGenXHardwareMask()) {
                // Reduce the per-lane comparison to a single branch condition.
                llvm::Value *allMatch = ctx->All(equalsMask);
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, allMatch);
            } else {
                // If uniform CF is emulated, pass vector value to BranchInst
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, equalsMask);
            }

            // value0 was false for all running lanes, so use its value as
            // the overall result.
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            ctx->StoreInst(value0, retPtr, arg0->GetType(), retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);

            // Otherwise we need to evaluate value1, but again with the
            // mask set to only be on for the lanes where value0 was true.
            // For the lanes where value0 was false, execution needs to be
            // disabled: mask = (mask & value0).
            ctx->SetCurrentBasicBlock(bbEvalValue1);
            ctx->SetInternalMaskAnd(oldMask, value0);

            llvm::Value *value1 = arg1->GetValue(ctx);
            if (value1 == NULL) {
                // NOTE(review): this error return leaves the internal mask
                // as modified above; it is not restored to oldMask here.
                AssertPos(pos, m->errorCount > 0);
                return NULL;
            }

            // And as in the || case, we compute the overall result by
            // masking off the valid lanes before we AND them together:
            // result = (value0 & old_mask) & (value1 & current_mask)
            llvm::Value *value0AndMask = ctx->BinaryOperator(llvm::Instruction::And, value0, oldFullMask, "op&mask");
            llvm::Value *value1AndMask =
                ctx->BinaryOperator(llvm::Instruction::And, value1, ctx->GetInternalMask(), "value1&mask");
            // NOTE(review): the name "or_result" below is attached to an And
            // instruction; presumably a leftover from the || branch -- confirm
            // before renaming, in case anything matches on the emitted name.
            llvm::Value *result =
                ctx->BinaryOperator(llvm::Instruction::And, value0AndMask, value1AndMask, "or_result");
            ctx->StoreInst(result, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        }

        // And finally we always end up in bbLogicalDone, where we restore
        // the old mask and return the computed result
        ctx->SetCurrentBasicBlock(bbLogicalDone);
        ctx->SetInternalMask(oldMask);
        return ctx->LoadInst(retPtr, retType);
    }
}
1885
1886 /* Returns true if shifting right by the given amount will lead to
1887 inefficient code. (Assumes x86 target. May also warn inaccurately if
1888 later optimization simplify the shift amount more than we are able to
1889 see at this point.) */
lIsDifficultShiftAmount(Expr * expr)1890 static bool lIsDifficultShiftAmount(Expr *expr) {
1891 // Uniform shifts (of uniform values) are no problem.
1892 if (expr->GetType()->IsVaryingType() == false)
1893 return false;
1894
1895 ConstExpr *ce = llvm::dyn_cast<ConstExpr>(expr);
1896 if (ce) {
1897 // If the shift is by a constant amount, *and* it's the same amount
1898 // in all vector lanes, we're in good shape.
1899 uint32_t amount[ISPC_MAX_NVEC];
1900 int count = ce->GetValues(amount);
1901 for (int i = 1; i < count; ++i)
1902 if (amount[i] != amount[0])
1903 return true;
1904 return false;
1905 }
1906
1907 TypeCastExpr *tce = llvm::dyn_cast<TypeCastExpr>(expr);
1908 if (tce && tce->expr) {
1909 // Finally, if the shift amount is given by a uniform value that's
1910 // been smeared out into a varying, we have the same shift for all
1911 // lanes and are also in good shape.
1912 return (tce->expr->GetType()->IsUniformType() == false);
1913 }
1914
1915 return true;
1916 }
1917
HasAmbiguousVariability(std::vector<const Expr * > & warn) const1918 bool BinaryExpr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
1919 bool isArg0Amb = false;
1920 bool isArg1Amb = false;
1921 if (arg0 != NULL) {
1922 const Type *type0 = arg0->GetType();
1923 if (arg0->HasAmbiguousVariability(warn)) {
1924 isArg0Amb = true;
1925 } else if ((type0 != NULL) && (type0->IsVaryingType())) {
1926 // If either arg is varying, then the expression is un-ambiguously varying.
1927 return false;
1928 }
1929 }
1930 if (arg1 != NULL) {
1931 const Type *type1 = arg1->GetType();
1932 if (arg1->HasAmbiguousVariability(warn)) {
1933 isArg1Amb = true;
1934 } else if ((type1 != NULL) && (type1->IsVaryingType())) {
1935 // If either arg is varying, then the expression is un-ambiguously varying.
1936 return false;
1937 }
1938 }
1939 if (isArg0Amb || isArg1Amb) {
1940 return true;
1941 }
1942
1943 return false;
1944 }
1945
GetValue(FunctionEmitContext * ctx) const1946 llvm::Value *BinaryExpr::GetValue(FunctionEmitContext *ctx) const {
1947 if (!arg0 || !arg1) {
1948 AssertPos(pos, m->errorCount > 0);
1949 return NULL;
1950 }
1951
1952 // Handle these specially, since we want to short-circuit their evaluation...
1953 if (op == LogicalAnd || op == LogicalOr)
1954 return lEmitLogicalOp(op, arg0, arg1, ctx, pos);
1955
1956 llvm::Value *value0 = arg0->GetValue(ctx);
1957 llvm::Value *value1 = arg1->GetValue(ctx);
1958 if (value0 == NULL || value1 == NULL) {
1959 AssertPos(pos, m->errorCount > 0);
1960 return NULL;
1961 }
1962
1963 ctx->SetDebugPos(pos);
1964
1965 switch (op) {
1966 case Add:
1967 case Sub:
1968 case Mul:
1969 case Div:
1970 case Mod:
1971 return lEmitBinaryArith(op, value0, value1, arg0->GetType(), arg1->GetType(), ctx, pos);
1972 case Lt:
1973 case Gt:
1974 case Le:
1975 case Ge:
1976 case Equal:
1977 case NotEqual:
1978 return lEmitBinaryCmp(op, value0, value1, arg0->GetType(), ctx, pos);
1979 case Shl:
1980 case Shr:
1981 case BitAnd:
1982 case BitXor:
1983 case BitOr: {
1984 if (op == Shr && lIsDifficultShiftAmount(arg1))
1985 PerformanceWarning(pos, "Shift right is inefficient for "
1986 "varying shift amounts.");
1987 return lEmitBinaryBitOp(op, value0, value1, arg0->GetType()->IsUnsignedType(), ctx);
1988 }
1989 case Comma:
1990 return value1;
1991 default:
1992 FATAL("logic error");
1993 return NULL;
1994 }
1995 }
1996
GetType() const1997 const Type *BinaryExpr::GetType() const {
1998 if (arg0 == NULL || arg1 == NULL)
1999 return NULL;
2000
2001 const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
2002 if (type0 == NULL || type1 == NULL)
2003 return NULL;
2004
2005 // If this hits, it means that our TypeCheck() method hasn't been
2006 // called before GetType() was called; adding two pointers is illegal
2007 // and will fail type checking and (int + ptr) should be canonicalized
2008 // into (ptr + int) by type checking.
2009 if (op == Add)
2010 AssertPos(pos, CastType<PointerType>(type1) == NULL);
2011
2012 if (op == Comma)
2013 return arg1->GetType();
2014
2015 if (CastType<PointerType>(type0) != NULL) {
2016 if (op == Add)
2017 // ptr + int -> ptr
2018 return type0;
2019 else if (op == Sub) {
2020 if (CastType<PointerType>(type1) != NULL) {
2021 // ptr - ptr -> ~ptrdiff_t
2022 const Type *diffType = (g->target->is32Bit() || g->opt.force32BitAddressing) ? AtomicType::UniformInt32
2023 : AtomicType::UniformInt64;
2024 if (type0->IsVaryingType() || type1->IsVaryingType())
2025 diffType = diffType->GetAsVaryingType();
2026 return diffType;
2027 } else
2028 // ptr - int -> ptr
2029 return type0;
2030 }
2031
2032 // otherwise fall through for these...
2033 AssertPos(pos, op == Lt || op == Gt || op == Le || op == Ge || op == Equal || op == NotEqual);
2034 }
2035
2036 const Type *exprType = Type::MoreGeneralType(type0, type1, pos, lOpString(op));
2037 // I don't think that MoreGeneralType should be able to fail after the
2038 // checks done in BinaryExpr::TypeCheck().
2039 AssertPos(pos, exprType != NULL);
2040
2041 switch (op) {
2042 case Add:
2043 case Sub:
2044 case Mul:
2045 case Div:
2046 case Mod:
2047 return exprType;
2048 case Lt:
2049 case Gt:
2050 case Le:
2051 case Ge:
2052 case Equal:
2053 case NotEqual:
2054 case LogicalAnd:
2055 case LogicalOr:
2056 return lMatchingBoolType(exprType);
2057 case Shl:
2058 case Shr:
2059 return type1->IsVaryingType() ? type0->GetAsVaryingType() : type0;
2060 case BitAnd:
2061 case BitXor:
2062 case BitOr:
2063 return exprType;
2064 case Comma:
2065 // handled above, so fall through here just in case
2066 default:
2067 FATAL("logic error in BinaryExpr::GetType()");
2068 return NULL;
2069 }
2070 }
2071
/* Expands to a 'case O:' label that applies the binary operator E to each
   pair of elements of v0 and v1 and stores the outcome in result.  The
   expansion refers to the names 'count', 'result', 'v0' and 'v1', so these
   must be visible at the point of use inside a switch statement.  The
   trailing 'break' is deliberately left without a semicolon; the user of
   the macro supplies it. */
#define FOLD_OP(O, E)                                                                                                  \
    case O:                                                                                                            \
        for (int i = 0; i < count; ++i)                                                                                \
            result[i] = (v0[i] E v1[i]);                                                                               \
        break
2077
/* Like FOLD_OP, but additionally recomputes each element with both
   operands cast to the reference type TRef and issues a warning when that
   reference value differs from the value stored in result.  Besides
   'count', 'result', 'v0' and 'v1', the expansion also refers to 'pos'
   (position for the warning) and 'carg0' (its type is named in the
   message).  The trailing 'break' is left without a semicolon. */
#define FOLD_OP_REF(O, E, TRef)                                                                                        \
    case O:                                                                                                            \
        for (int i = 0; i < count; ++i) {                                                                              \
            result[i] = (v0[i] E v1[i]);                                                                               \
            TRef r = (TRef)v0[i] E(TRef) v1[i];                                                                        \
            if (result[i] != r)                                                                                        \
                Warning(pos, "Binary expression with type \"%s\" can't represent value.",                              \
                        carg0->GetType()->GetString().c_str());                                                        \
        }                                                                                                              \
        break
2088
/** Count the zero bits above the most significant set bit of 'val'.

    The bits of 'val' are examined from the most significant position
    downwards; the number of positions passed before the first set bit is
    returned.  A value of zero yields the full bit width of T, and a
    negative value of a signed T yields 0 because its top bit is set.

    The bits are tested by shifting 'val' itself instead of sliding a
    single-bit mask built from T(1) << (width - 1): for signed T that mask
    is a negative number, so building and right-shifting it depended on
    implementation-defined conversion and sign-extending shifts.

    @param val  Integer value to inspect.
    @return Number of leading zero bits, in the range [0, sizeof(T) * CHAR_BIT].
 */
template <typename T> static int countLeadingZeros(T val) {
    const int bitCount = (int)(sizeof(T) * CHAR_BIT);

    for (int bit = bitCount - 1; bit >= 0; --bit) {
        // For a negative 'val' this is true on the very first iteration
        // regardless of how the shift treats the sign.
        if ((val >> bit) & 1) {
            return bitCount - 1 - bit;
        }
    }
    return bitCount;
}
2104
2105 /** Constant fold the binary integer operations that aren't also applicable
2106 to floating-point types.
2107 */
2108 template <typename T, typename TRef>
lConstFoldBinaryIntOp(BinaryExpr::Op op,const T * v0,const T * v1,ConstExpr * carg0,SourcePos pos)2109 static ConstExpr *lConstFoldBinaryIntOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *carg0, SourcePos pos) {
2110 T result[ISPC_MAX_NVEC];
2111 int count = carg0->Count();
2112
2113 switch (op) {
2114 FOLD_OP_REF(BinaryExpr::Shr, >>, TRef);
2115 FOLD_OP_REF(BinaryExpr::BitAnd, &, TRef);
2116 FOLD_OP_REF(BinaryExpr::BitXor, ^, TRef);
2117 FOLD_OP_REF(BinaryExpr::BitOr, |, TRef);
2118
2119 case BinaryExpr::Shl:
2120 for (int i = 0; i < count; ++i) {
2121 result[i] = (T(v0[i]) << v1[i]);
2122 if (v1[i] > countLeadingZeros(v0[i])) {
2123 Warning(pos, "Binary expression with type \"%s\" can't represent value.",
2124 carg0->GetType()->GetString().c_str());
2125 }
2126 }
2127 break;
2128 case BinaryExpr::Mod:
2129 for (int i = 0; i < count; ++i) {
2130 if (v1[i] == 0) {
2131 Warning(pos, "Remainder by zero is undefined.");
2132 return NULL;
2133 } else {
2134 result[i] = (v0[i] % v1[i]);
2135 }
2136 }
2137 break;
2138 default:
2139 return NULL;
2140 }
2141
2142 return new ConstExpr(carg0->GetType(), result, carg0->pos);
2143 }
2144
/** Constant fold the binary logical ops.

    Handles the comparisons (<, >, <=, >=, ==, !=) and the logical
    operators (&&, ||), element by element.

    @param op     Operator to fold.
    @param v0     Element values of the left operand.
    @param v1     Element values of the right operand.
    @param carg0  Left operand; supplies the element count, the variability
                  of the result and the result position.
    @return A new bool ConstExpr (uniform if carg0's type is uniform,
            varying otherwise), or NULL if 'op' is not one of the operators
            listed above.
 */
template <typename T>
static ConstExpr *lConstFoldBinaryLogicalOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *carg0) {
    // 'result' and 'count' are referenced by name from the FOLD_OP expansions.
    bool result[ISPC_MAX_NVEC];
    int count = carg0->Count();

    switch (op) {
        FOLD_OP(BinaryExpr::Lt, <);
        FOLD_OP(BinaryExpr::Gt, >);
        FOLD_OP(BinaryExpr::Le, <=);
        FOLD_OP(BinaryExpr::Ge, >=);
        FOLD_OP(BinaryExpr::Equal, ==);
        FOLD_OP(BinaryExpr::NotEqual, !=);
        FOLD_OP(BinaryExpr::LogicalAnd, &&);
        FOLD_OP(BinaryExpr::LogicalOr, ||);
    default:
        return NULL;
    }

    // The result is always bool; only its variability follows the operand.
    const Type *rType = carg0->GetType()->IsUniformType() ? AtomicType::UniformBool : AtomicType::VaryingBool;
    return new ConstExpr(rType, result, carg0->pos);
}
2168
2169 /** Constant fold binary arithmetic ops.
2170 */
2171 template <typename T, typename TRef>
lConstFoldBinaryArithOp(BinaryExpr::Op op,const T * v0,const T * v1,ConstExpr * carg0,SourcePos pos)2172 static ConstExpr *lConstFoldBinaryArithOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *carg0,
2173 SourcePos pos) {
2174 T result[ISPC_MAX_NVEC];
2175 int count = carg0->Count();
2176
2177 switch (op) {
2178 FOLD_OP_REF(BinaryExpr::Add, +, TRef);
2179 FOLD_OP_REF(BinaryExpr::Sub, -, TRef);
2180 FOLD_OP_REF(BinaryExpr::Mul, *, TRef);
2181 case BinaryExpr::Div:
2182 for (int i = 0; i < count; ++i) {
2183 if (v1[i] == 0) {
2184 Warning(pos, "Division by zero is undefined.");
2185 return NULL;
2186 } else {
2187 result[i] = (v0[i] / v1[i]);
2188 }
2189 }
2190 break;
2191 default:
2192 return NULL;
2193 }
2194
2195 return new ConstExpr(carg0->GetType(), result, carg0->pos);
2196 }
2197
/** Constant fold the various boolean binary ops.

    Handles the bitwise operators (&, ^, |), the comparisons
    (<, >, <=, >=, ==, !=) and the logical operators (&&, ||) on bool
    operands, element by element.

    @param op     Operator to fold.
    @param v0     Element values of the left operand.
    @param v1     Element values of the right operand.
    @param carg0  Left operand; supplies element count, result type and
                  result position.
    @return A new ConstExpr of carg0's type holding the folded values, or
            NULL if 'op' is not one of the operators listed above.
 */
static ConstExpr *lConstFoldBoolBinaryOp(BinaryExpr::Op op, const bool *v0, const bool *v1, ConstExpr *carg0) {
    // 'result' and 'count' are referenced by name from the FOLD_OP expansions.
    bool result[ISPC_MAX_NVEC];
    int count = carg0->Count();

    switch (op) {
        FOLD_OP(BinaryExpr::BitAnd, &);
        FOLD_OP(BinaryExpr::BitXor, ^);
        FOLD_OP(BinaryExpr::BitOr, |);
        FOLD_OP(BinaryExpr::Lt, <);
        FOLD_OP(BinaryExpr::Gt, >);
        FOLD_OP(BinaryExpr::Le, <=);
        FOLD_OP(BinaryExpr::Ge, >=);
        FOLD_OP(BinaryExpr::Equal, ==);
        FOLD_OP(BinaryExpr::NotEqual, !=);
        FOLD_OP(BinaryExpr::LogicalAnd, &&);
        FOLD_OP(BinaryExpr::LogicalOr, ||);
    default:
        return NULL;
    }

    return new ConstExpr(carg0->GetType(), result, carg0->pos);
}
2222
2223 template <typename T>
lConstFoldBinaryFPOp(ConstExpr * constArg0,ConstExpr * constArg1,BinaryExpr::Op op,BinaryExpr * origExpr,SourcePos pos)2224 static Expr *lConstFoldBinaryFPOp(ConstExpr *constArg0, ConstExpr *constArg1, BinaryExpr::Op op, BinaryExpr *origExpr,
2225 SourcePos pos) {
2226 T v0[ISPC_MAX_NVEC], v1[ISPC_MAX_NVEC];
2227 constArg0->GetValues(v0);
2228 constArg1->GetValues(v1);
2229 ConstExpr *ret;
2230 if ((ret = lConstFoldBinaryArithOp<T, T>(op, v0, v1, constArg0, pos)) != NULL)
2231 return ret;
2232 else if ((ret = lConstFoldBinaryLogicalOp(op, v0, v1, constArg0)) != NULL)
2233 return ret;
2234 else
2235 return origExpr;
2236 }
2237
2238 template <typename T, typename TRef>
lConstFoldBinaryIntOp(ConstExpr * constArg0,ConstExpr * constArg1,BinaryExpr::Op op,BinaryExpr * origExpr,SourcePos pos)2239 static Expr *lConstFoldBinaryIntOp(ConstExpr *constArg0, ConstExpr *constArg1, BinaryExpr::Op op, BinaryExpr *origExpr,
2240 SourcePos pos) {
2241 T v0[ISPC_MAX_NVEC], v1[ISPC_MAX_NVEC];
2242 constArg0->GetValues(v0);
2243 constArg1->GetValues(v1);
2244 ConstExpr *ret;
2245 if ((ret = lConstFoldBinaryArithOp<T, TRef>(op, v0, v1, constArg0, pos)) != NULL)
2246 return ret;
2247 else if ((ret = lConstFoldBinaryIntOp<T, TRef>(op, v0, v1, constArg0, pos)) != NULL)
2248 return ret;
2249 else if ((ret = lConstFoldBinaryLogicalOp(op, v0, v1, constArg0)) != NULL)
2250 return ret;
2251 else
2252 return origExpr;
2253 }
2254
Optimize()2255 Expr *BinaryExpr::Optimize() {
2256 if (arg0 == NULL || arg1 == NULL)
2257 return NULL;
2258
2259 ConstExpr *constArg0 = llvm::dyn_cast<ConstExpr>(arg0);
2260 ConstExpr *constArg1 = llvm::dyn_cast<ConstExpr>(arg1);
2261
2262 if (g->opt.fastMath) {
2263 // optimizations related to division by floats..
2264
2265 // transform x / const -> x * (1/const)
2266 if (op == Div && constArg1 != NULL) {
2267 const Type *type1 = constArg1->GetType();
2268 if (Type::EqualIgnoringConst(type1, AtomicType::UniformFloat) ||
2269 Type::EqualIgnoringConst(type1, AtomicType::VaryingFloat)) {
2270 float inv[ISPC_MAX_NVEC];
2271 int count = constArg1->GetValues(inv);
2272 for (int i = 0; i < count; ++i)
2273 inv[i] = 1.f / inv[i];
2274 Expr *einv = new ConstExpr(type1, inv, constArg1->pos);
2275 Expr *e = new BinaryExpr(Mul, arg0, einv, pos);
2276 e = ::TypeCheck(e);
2277 if (e == NULL)
2278 return NULL;
2279 return ::Optimize(e);
2280 }
2281 }
2282
2283 // transform x / y -> x * rcp(y)
2284 if (op == Div) {
2285 const Type *type1 = arg1->GetType();
2286 if (Type::EqualIgnoringConst(type1, AtomicType::UniformFloat) ||
2287 Type::EqualIgnoringConst(type1, AtomicType::VaryingFloat)) {
2288 // Get the symbol for the appropriate builtin
2289 std::vector<Symbol *> rcpFuns;
2290 m->symbolTable->LookupFunction("rcp", &rcpFuns);
2291 if (rcpFuns.size() > 0) {
2292 Expr *rcpSymExpr = new FunctionSymbolExpr("rcp", rcpFuns, pos);
2293 ExprList *args = new ExprList(arg1, arg1->pos);
2294 Expr *rcpCall = new FunctionCallExpr(rcpSymExpr, args, arg1->pos);
2295 rcpCall = ::TypeCheck(rcpCall);
2296 if (rcpCall == NULL)
2297 return NULL;
2298 rcpCall = ::Optimize(rcpCall);
2299 if (rcpCall == NULL)
2300 return NULL;
2301
2302 Expr *ret = new BinaryExpr(Mul, arg0, rcpCall, pos);
2303 ret = ::TypeCheck(ret);
2304 if (ret == NULL)
2305 return NULL;
2306 return ::Optimize(ret);
2307 } else
2308 Warning(pos, "rcp() not found from stdlib. Can't apply "
2309 "fast-math rcp optimization.");
2310 }
2311 }
2312 }
2313
2314 // From here on out, we're just doing constant folding, so if both args
2315 // aren't constants then we're done...
2316 if (constArg0 == NULL || constArg1 == NULL)
2317 return this;
2318
2319 AssertPos(pos, Type::EqualIgnoringConst(arg0->GetType(), arg1->GetType()));
2320 const Type *type = arg0->GetType()->GetAsNonConstType();
2321 if (Type::Equal(type, AtomicType::UniformFloat) || Type::Equal(type, AtomicType::VaryingFloat)) {
2322 return lConstFoldBinaryFPOp<float>(constArg0, constArg1, op, this, pos);
2323 } else if (Type::Equal(type, AtomicType::UniformDouble) || Type::Equal(type, AtomicType::VaryingDouble)) {
2324 return lConstFoldBinaryFPOp<double>(constArg0, constArg1, op, this, pos);
2325 } else if (Type::Equal(type, AtomicType::UniformInt8) || Type::Equal(type, AtomicType::VaryingInt8)) {
2326 return lConstFoldBinaryIntOp<int8_t, int64_t>(constArg0, constArg1, op, this, pos);
2327 } else if (Type::Equal(type, AtomicType::UniformUInt8) || Type::Equal(type, AtomicType::VaryingUInt8)) {
2328 return lConstFoldBinaryIntOp<uint8_t, uint64_t>(constArg0, constArg1, op, this, pos);
2329 } else if (Type::Equal(type, AtomicType::UniformInt16) || Type::Equal(type, AtomicType::VaryingInt16)) {
2330 return lConstFoldBinaryIntOp<int16_t, int64_t>(constArg0, constArg1, op, this, pos);
2331 } else if (Type::Equal(type, AtomicType::UniformUInt16) || Type::Equal(type, AtomicType::VaryingUInt16)) {
2332 return lConstFoldBinaryIntOp<uint16_t, uint64_t>(constArg0, constArg1, op, this, pos);
2333 } else if (Type::Equal(type, AtomicType::UniformInt32) || Type::Equal(type, AtomicType::VaryingInt32)) {
2334 return lConstFoldBinaryIntOp<int32_t, int64_t>(constArg0, constArg1, op, this, pos);
2335 } else if (Type::Equal(type, AtomicType::UniformUInt32) || Type::Equal(type, AtomicType::VaryingUInt32)) {
2336 return lConstFoldBinaryIntOp<uint32_t, uint64_t>(constArg0, constArg1, op, this, pos);
2337 } else if (Type::Equal(type, AtomicType::UniformInt64) || Type::Equal(type, AtomicType::VaryingInt64)) {
2338 return lConstFoldBinaryIntOp<int64_t, int64_t>(constArg0, constArg1, op, this, pos);
2339 } else if (Type::Equal(type, AtomicType::UniformUInt64) || Type::Equal(type, AtomicType::VaryingUInt64)) {
2340 return lConstFoldBinaryIntOp<uint64_t, uint64_t>(constArg0, constArg1, op, this, pos);
2341 } else if (Type::Equal(type, AtomicType::UniformBool) || Type::Equal(type, AtomicType::VaryingBool)) {
2342 bool v0[ISPC_MAX_NVEC], v1[ISPC_MAX_NVEC];
2343 constArg0->GetValues(v0);
2344 constArg1->GetValues(v1);
2345 ConstExpr *ret;
2346 if ((ret = lConstFoldBoolBinaryOp(op, v0, v1, constArg0)) != NULL)
2347 return ret;
2348 else if ((ret = lConstFoldBinaryLogicalOp(op, v0, v1, constArg0)) != NULL)
2349 return ret;
2350 else
2351 return this;
2352 } else
2353 return this;
2354 }
2355
TypeCheck()2356 Expr *BinaryExpr::TypeCheck() {
2357 if (arg0 == NULL || arg1 == NULL)
2358 return NULL;
2359
2360 const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
2361 if (type0 == NULL || type1 == NULL)
2362 return NULL;
2363
2364 // If either operand is a reference, dereference it before we move
2365 // forward
2366 if (CastType<ReferenceType>(type0) != NULL) {
2367 arg0 = new RefDerefExpr(arg0, arg0->pos);
2368 type0 = arg0->GetType();
2369 AssertPos(pos, type0 != NULL);
2370 }
2371 if (CastType<ReferenceType>(type1) != NULL) {
2372 arg1 = new RefDerefExpr(arg1, arg1->pos);
2373 type1 = arg1->GetType();
2374 AssertPos(pos, type1 != NULL);
2375 }
2376
2377 // Convert arrays to pointers to their first elements
2378 if (CastType<ArrayType>(type0) != NULL) {
2379 arg0 = lArrayToPointer(arg0);
2380 type0 = arg0->GetType();
2381 }
2382 if (CastType<ArrayType>(type1) != NULL) {
2383 arg1 = lArrayToPointer(arg1);
2384 type1 = arg1->GetType();
2385 }
2386
2387 // Prohibit binary operators with SOA types
2388 if (type0->GetSOAWidth() > 0) {
2389 Error(arg0->pos,
2390 "Illegal to use binary operator %s with SOA type "
2391 "\"%s\".",
2392 lOpString(op), type0->GetString().c_str());
2393 return NULL;
2394 }
2395 if (type1->GetSOAWidth() > 0) {
2396 Error(arg1->pos,
2397 "Illegal to use binary operator %s with SOA type "
2398 "\"%s\".",
2399 lOpString(op), type1->GetString().c_str());
2400 return NULL;
2401 }
2402
2403 const PointerType *pt0 = CastType<PointerType>(type0);
2404 const PointerType *pt1 = CastType<PointerType>(type1);
2405 if (pt0 != NULL && pt1 != NULL && op == Sub) {
2406 // Pointer subtraction
2407 if (PointerType::IsVoidPointer(type0)) {
2408 Error(pos,
2409 "Illegal to perform pointer arithmetic "
2410 "on \"%s\" type.",
2411 type0->GetString().c_str());
2412 return NULL;
2413 }
2414 if (PointerType::IsVoidPointer(type1)) {
2415 Error(pos,
2416 "Illegal to perform pointer arithmetic "
2417 "on \"%s\" type.",
2418 type1->GetString().c_str());
2419 return NULL;
2420 }
2421 if (CastType<UndefinedStructType>(pt0->GetBaseType())) {
2422 Error(pos,
2423 "Illegal to perform pointer arithmetic "
2424 "on undefined struct type \"%s\".",
2425 pt0->GetString().c_str());
2426 return NULL;
2427 }
2428 if (CastType<UndefinedStructType>(pt1->GetBaseType())) {
2429 Error(pos,
2430 "Illegal to perform pointer arithmetic "
2431 "on undefined struct type \"%s\".",
2432 pt1->GetString().c_str());
2433 return NULL;
2434 }
2435
2436 const Type *t = Type::MoreGeneralType(type0, type1, pos, "-");
2437 if (t == NULL)
2438 return NULL;
2439
2440 arg0 = TypeConvertExpr(arg0, t, "pointer subtraction");
2441 arg1 = TypeConvertExpr(arg1, t, "pointer subtraction");
2442 if (arg0 == NULL || arg1 == NULL)
2443 return NULL;
2444
2445 return this;
2446 } else if (((pt0 != NULL || pt1 != NULL) && op == Add) || (pt0 != NULL && op == Sub)) {
2447 // Handle ptr + int, int + ptr, ptr - int
2448 if (pt0 != NULL && pt1 != NULL) {
2449 Error(pos, "Illegal to add two pointer types \"%s\" and \"%s\".", pt0->GetString().c_str(),
2450 pt1->GetString().c_str());
2451 return NULL;
2452 } else if (pt1 != NULL) {
2453 // put in canonical order with the pointer as the first operand
2454 // for GetValue()
2455 std::swap(arg0, arg1);
2456 std::swap(type0, type1);
2457 std::swap(pt0, pt1);
2458 }
2459
2460 AssertPos(pos, pt0 != NULL);
2461
2462 if (PointerType::IsVoidPointer(pt0)) {
2463 Error(pos,
2464 "Illegal to perform pointer arithmetic "
2465 "on \"%s\" type.",
2466 pt0->GetString().c_str());
2467 return NULL;
2468 }
2469 if (CastType<UndefinedStructType>(pt0->GetBaseType())) {
2470 Error(pos,
2471 "Illegal to perform pointer arithmetic "
2472 "on undefined struct type \"%s\".",
2473 pt0->GetString().c_str());
2474 return NULL;
2475 }
2476
2477 const Type *offsetType = g->target->is32Bit() ? AtomicType::UniformInt32 : AtomicType::UniformInt64;
2478 if (pt0->IsVaryingType())
2479 offsetType = offsetType->GetAsVaryingType();
2480 if (type1->IsVaryingType()) {
2481 arg0 = TypeConvertExpr(arg0, type0->GetAsVaryingType(), "pointer addition");
2482 offsetType = offsetType->GetAsVaryingType();
2483 AssertPos(pos, arg0 != NULL);
2484 }
2485
2486 arg1 = TypeConvertExpr(arg1, offsetType, lOpString(op));
2487 if (arg1 == NULL)
2488 return NULL;
2489
2490 return this;
2491 }
2492
2493 switch (op) {
2494 case Shl:
2495 case Shr:
2496 case BitAnd:
2497 case BitXor:
2498 case BitOr: {
2499 // Must have integer or bool-typed operands for these bit-related
2500 // ops; don't do any implicit conversions from floats here...
2501 if (!type0->IsIntType() && !type0->IsBoolType()) {
2502 Error(arg0->pos,
2503 "First operand to binary operator \"%s\" must be "
2504 "an integer or bool.",
2505 lOpString(op));
2506 return NULL;
2507 }
2508 if (!type1->IsIntType() && !type1->IsBoolType()) {
2509 Error(arg1->pos,
2510 "Second operand to binary operator \"%s\" must be "
2511 "an integer or bool.",
2512 lOpString(op));
2513 return NULL;
2514 }
2515
2516 if (op == Shl || op == Shr) {
2517 bool isVarying = (type0->IsVaryingType() || type1->IsVaryingType());
2518 if (isVarying) {
2519 arg0 = TypeConvertExpr(arg0, type0->GetAsVaryingType(), "shift operator");
2520 if (arg0 == NULL)
2521 return NULL;
2522 type0 = arg0->GetType();
2523 }
2524 arg1 = TypeConvertExpr(arg1, type0, "shift operator");
2525 if (arg1 == NULL)
2526 return NULL;
2527 } else {
2528 const Type *promotedType = Type::MoreGeneralType(type0, type1, arg0->pos, "binary bit op");
2529 if (promotedType == NULL)
2530 return NULL;
2531
2532 arg0 = TypeConvertExpr(arg0, promotedType, "binary bit op");
2533 arg1 = TypeConvertExpr(arg1, promotedType, "binary bit op");
2534 if (arg0 == NULL || arg1 == NULL)
2535 return NULL;
2536 }
2537 return this;
2538 }
2539 case Add:
2540 case Sub:
2541 case Mul:
2542 case Div:
2543 case Mod: {
2544 // Must be numeric type for these. (And mod is special--can't be float)
2545 if (!type0->IsNumericType() || (op == Mod && type0->IsFloatType())) {
2546 Error(arg0->pos,
2547 "First operand to binary operator \"%s\" is of "
2548 "invalid type \"%s\".",
2549 lOpString(op), type0->GetString().c_str());
2550 return NULL;
2551 }
2552 if (!type1->IsNumericType() || (op == Mod && type1->IsFloatType())) {
2553 Error(arg1->pos,
2554 "First operand to binary operator \"%s\" is of "
2555 "invalid type \"%s\".",
2556 lOpString(op), type1->GetString().c_str());
2557 return NULL;
2558 }
2559
2560 const Type *promotedType = Type::MoreGeneralType(type0, type1, Union(arg0->pos, arg1->pos), lOpString(op));
2561 if (promotedType == NULL)
2562 return NULL;
2563
2564 arg0 = TypeConvertExpr(arg0, promotedType, lOpString(op));
2565 arg1 = TypeConvertExpr(arg1, promotedType, lOpString(op));
2566 if (arg0 == NULL || arg1 == NULL)
2567 return NULL;
2568 return this;
2569 }
2570 case Lt:
2571 case Gt:
2572 case Le:
2573 case Ge:
2574 case Equal:
2575 case NotEqual: {
2576 const PointerType *pt0 = CastType<PointerType>(type0);
2577 const PointerType *pt1 = CastType<PointerType>(type1);
2578
2579 // Convert '0' in expressions where the other expression is a
2580 // pointer type to a NULL pointer.
2581 if (pt0 != NULL && lIsAllIntZeros(arg1)) {
2582 arg1 = new NullPointerExpr(pos);
2583 type1 = arg1->GetType();
2584 pt1 = CastType<PointerType>(type1);
2585 } else if (pt1 != NULL && lIsAllIntZeros(arg0)) {
2586 arg0 = new NullPointerExpr(pos);
2587 type0 = arg1->GetType();
2588 pt0 = CastType<PointerType>(type0);
2589 }
2590
2591 if (pt0 == NULL && pt1 == NULL) {
2592 if (!type0->IsBoolType() && !type0->IsNumericType()) {
2593 Error(arg0->pos,
2594 "First operand to operator \"%s\" is of "
2595 "non-comparable type \"%s\".",
2596 lOpString(op), type0->GetString().c_str());
2597 return NULL;
2598 }
2599 if (!type1->IsBoolType() && !type1->IsNumericType()) {
2600 Error(arg1->pos,
2601 "Second operand to operator \"%s\" is of "
2602 "non-comparable type \"%s\".",
2603 lOpString(op), type1->GetString().c_str());
2604 return NULL;
2605 }
2606 }
2607
2608 const Type *promotedType = Type::MoreGeneralType(type0, type1, arg0->pos, lOpString(op));
2609 if (promotedType == NULL)
2610 return NULL;
2611
2612 arg0 = TypeConvertExpr(arg0, promotedType, lOpString(op));
2613 arg1 = TypeConvertExpr(arg1, promotedType, lOpString(op));
2614 if (arg0 == NULL || arg1 == NULL)
2615 return NULL;
2616 return this;
2617 }
2618 case LogicalAnd:
2619 case LogicalOr: {
2620 // For now, we just type convert to boolean types, of the same
2621 // variability as the original types. (When generating code, it's
2622 // useful to have preserved the uniform/varying distinction.)
2623 const AtomicType *boolType0 = type0->IsUniformType() ? AtomicType::UniformBool : AtomicType::VaryingBool;
2624 const AtomicType *boolType1 = type1->IsUniformType() ? AtomicType::UniformBool : AtomicType::VaryingBool;
2625
2626 const Type *destType0 = NULL, *destType1 = NULL;
2627 const VectorType *vtype0 = CastType<VectorType>(type0);
2628 const VectorType *vtype1 = CastType<VectorType>(type1);
2629 if (vtype0 && vtype1) {
2630 int sz0 = vtype0->GetElementCount(), sz1 = vtype1->GetElementCount();
2631 if (sz0 != sz1) {
2632 Error(pos,
2633 "Can't do logical operation \"%s\" between vector types of "
2634 "different sizes (%d vs. %d).",
2635 lOpString(op), sz0, sz1);
2636 return NULL;
2637 }
2638 destType0 = new VectorType(boolType0, sz0);
2639 destType1 = new VectorType(boolType1, sz1);
2640 } else if (vtype0 != NULL) {
2641 destType0 = new VectorType(boolType0, vtype0->GetElementCount());
2642 destType1 = new VectorType(boolType1, vtype0->GetElementCount());
2643 } else if (vtype1 != NULL) {
2644 destType0 = new VectorType(boolType0, vtype1->GetElementCount());
2645 destType1 = new VectorType(boolType1, vtype1->GetElementCount());
2646 } else {
2647 destType0 = boolType0;
2648 destType1 = boolType1;
2649 }
2650
2651 arg0 = TypeConvertExpr(arg0, destType0, lOpString(op));
2652 arg1 = TypeConvertExpr(arg1, destType1, lOpString(op));
2653 if (arg0 == NULL || arg1 == NULL)
2654 return NULL;
2655 return this;
2656 }
2657 case Comma:
2658 return this;
2659 default:
2660 FATAL("logic error");
2661 return NULL;
2662 }
2663 }
2664
GetLValueType() const2665 const Type *BinaryExpr::GetLValueType() const {
2666 const Type *t = GetType();
2667 if (CastType<PointerType>(t) != NULL) {
2668 // Are we doing something like (basePtr + offset)[...] = ...
2669 return t;
2670 } else {
2671 return NULL;
2672 }
2673 }
2674
EstimateCost() const2675 int BinaryExpr::EstimateCost() const {
2676 if (llvm::dyn_cast<ConstExpr>(arg0) != NULL && llvm::dyn_cast<ConstExpr>(arg1) != NULL)
2677 return 0;
2678
2679 return (op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP : COST_SIMPLE_ARITH_LOGIC_OP;
2680 }
2681
Print() const2682 void BinaryExpr::Print() const {
2683 if (!arg0 || !arg1 || !GetType())
2684 return;
2685
2686 printf("[ %s ] (", GetType()->GetString().c_str());
2687 arg0->Print();
2688 printf(" %s ", lOpString(op));
2689 arg1->Print();
2690 printf(")");
2691 pos.Print();
2692 }
2693
lGetBinaryExprStorageConstant(const Type * type,const BinaryExpr * bExpr,bool isStorageType)2694 static std::pair<llvm::Constant *, bool> lGetBinaryExprStorageConstant(const Type *type, const BinaryExpr *bExpr,
2695 bool isStorageType) {
2696
2697 const BinaryExpr::Op op = bExpr->op;
2698 Expr *arg0 = bExpr->arg0;
2699 Expr *arg1 = bExpr->arg1;
2700
2701 // Are we doing something like (basePtr + offset)[...] = ... for a Global
2702 // Variable
2703 if (!bExpr->GetLValueType())
2704 return std::pair<llvm::Constant *, bool>(NULL, false);
2705
2706 // We are limiting cases to just addition and subtraction involving
2707 // pointer addresses
2708 // Case 1 : first argument is a pointer address.
2709 // In this case as long as the second argument is a constant value, we are fine
2710 // Case 2 : second argument is a pointer address.
2711 // In this case, it has to be an addition with first argument as
2712 // a constant value.
2713 if (!((op == BinaryExpr::Op::Add) || (op == BinaryExpr::Op::Sub)))
2714 return std::pair<llvm::Constant *, bool>(NULL, false);
2715 if (op == BinaryExpr::Op::Sub) {
2716 // Ignore cases where subtrahend is a PointerType
2717 // Eg. b - 5 is valid but 5 - b is not.
2718 if (CastType<PointerType>(arg1->GetType()))
2719 return std::pair<llvm::Constant *, bool>(NULL, false);
2720 }
2721
2722 // 'isNotValidForMultiTargetGlobal' is required to let the caller know
2723 // that the llvm::constant value returned cannot be used in case of
2724 // multi-target compilation for initialization of globals. This is due
2725 // to different constant values for different targets, i.e. computation
2726 // involving sizeof() of varying types.
2727 // Since converting expr to constant can be a recursive process, we need
2728 // to ensure that if the flag is set by any expr in the chain, it's
2729 // reflected in the final return value.
2730 bool isNotValidForMultiTargetGlobal = false;
2731 if (const PointerType *pt0 = CastType<PointerType>(arg0->GetType())) {
2732 std::pair<llvm::Constant *, bool> c1Pair;
2733 if (isStorageType)
2734 c1Pair = arg0->GetStorageConstant(pt0);
2735 else
2736 c1Pair = arg0->GetConstant(pt0);
2737 llvm::Constant *c1 = c1Pair.first;
2738 isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c1Pair.second;
2739 ConstExpr *cExpr = llvm::dyn_cast<ConstExpr>(arg1);
2740 if ((cExpr == NULL) || (c1 == NULL))
2741 return std::pair<llvm::Constant *, bool>(NULL, false);
2742 std::pair<llvm::Constant *, bool> c2Pair;
2743 if (isStorageType)
2744 c2Pair = cExpr->GetStorageConstant(cExpr->GetType());
2745 else
2746 c2Pair = cExpr->GetConstant(cExpr->GetType());
2747 llvm::Constant *c2 = c2Pair.first;
2748 isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c2Pair.second;
2749 if (op == BinaryExpr::Op::Sub)
2750 c2 = llvm::ConstantExpr::getNeg(c2);
2751 llvm::Constant *c = llvm::ConstantExpr::getGetElementPtr(PTYPE(c1), c1, c2);
2752 return std::pair<llvm::Constant *, bool>(c, isNotValidForMultiTargetGlobal);
2753 } else if (const PointerType *pt1 = CastType<PointerType>(arg1->GetType())) {
2754 std::pair<llvm::Constant *, bool> c1Pair;
2755 if (isStorageType)
2756 c1Pair = arg1->GetStorageConstant(pt1);
2757 else
2758 c1Pair = arg1->GetConstant(pt1);
2759 llvm::Constant *c1 = c1Pair.first;
2760 isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c1Pair.second;
2761 ConstExpr *cExpr = llvm::dyn_cast<ConstExpr>(arg0);
2762 if ((cExpr == NULL) || (c1 == NULL))
2763 return std::pair<llvm::Constant *, bool>(NULL, false);
2764 std::pair<llvm::Constant *, bool> c2Pair;
2765 if (isStorageType)
2766 c2Pair = cExpr->GetStorageConstant(cExpr->GetType());
2767 else
2768 c2Pair = cExpr->GetConstant(cExpr->GetType());
2769 llvm::Constant *c2 = c2Pair.first;
2770 isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c2Pair.second;
2771 llvm::Constant *c = llvm::ConstantExpr::getGetElementPtr(PTYPE(c1), c1, c2);
2772 return std::pair<llvm::Constant *, bool>(c, isNotValidForMultiTargetGlobal);
2773 }
2774
2775 return std::pair<llvm::Constant *, bool>(NULL, false);
2776 }
2777
GetStorageConstant(const Type * type) const2778 std::pair<llvm::Constant *, bool> BinaryExpr::GetStorageConstant(const Type *type) const {
2779 return lGetBinaryExprStorageConstant(type, this, true);
2780 }
2781
GetConstant(const Type * type) const2782 std::pair<llvm::Constant *, bool> BinaryExpr::GetConstant(const Type *type) const {
2783
2784 return lGetBinaryExprStorageConstant(type, this, false);
2785 }
2786 ///////////////////////////////////////////////////////////////////////////
2787 // AssignExpr
2788
lOpString(AssignExpr::Op op)2789 static const char *lOpString(AssignExpr::Op op) {
2790 switch (op) {
2791 case AssignExpr::Assign:
2792 return "assignment operator";
2793 case AssignExpr::MulAssign:
2794 return "*=";
2795 case AssignExpr::DivAssign:
2796 return "/=";
2797 case AssignExpr::ModAssign:
2798 return "%%=";
2799 case AssignExpr::AddAssign:
2800 return "+=";
2801 case AssignExpr::SubAssign:
2802 return "-=";
2803 case AssignExpr::ShlAssign:
2804 return "<<=";
2805 case AssignExpr::ShrAssign:
2806 return ">>=";
2807 case AssignExpr::AndAssign:
2808 return "&=";
2809 case AssignExpr::XorAssign:
2810 return "^=";
2811 case AssignExpr::OrAssign:
2812 return "|=";
2813 default:
2814 FATAL("Missing op in lOpstring");
2815 return "";
2816 }
2817 }
2818
2819 /** Emit code to do an "assignment + operation" operator, e.g. "+=".
2820 */
lEmitOpAssign(AssignExpr::Op op,Expr * arg0,Expr * arg1,const Type * type,Symbol * baseSym,SourcePos pos,FunctionEmitContext * ctx)2821 static llvm::Value *lEmitOpAssign(AssignExpr::Op op, Expr *arg0, Expr *arg1, const Type *type, Symbol *baseSym,
2822 SourcePos pos, FunctionEmitContext *ctx) {
2823 llvm::Value *lv = arg0->GetLValue(ctx);
2824 if (!lv) {
2825 // FIXME: I think this test is unnecessary and that this case
2826 // should be caught during typechecking
2827 Error(pos, "Can't assign to left-hand side of expression.");
2828 return NULL;
2829 }
2830 const Type *lvalueType = arg0->GetLValueType();
2831 const Type *resultType = arg0->GetType();
2832 if (lvalueType == NULL || resultType == NULL)
2833 return NULL;
2834
2835 // Get the value on the right-hand side of the assignment+operation
2836 // operator and load the current value on the left-hand side.
2837 llvm::Value *rvalue = arg1->GetValue(ctx);
2838 llvm::Value *mask = lMaskForSymbol(baseSym, ctx);
2839 ctx->SetDebugPos(arg0->pos);
2840 llvm::Value *oldLHS = ctx->LoadInst(lv, mask, lvalueType);
2841 ctx->SetDebugPos(pos);
2842
2843 // Map the operator to the corresponding BinaryExpr::Op operator
2844 BinaryExpr::Op basicop;
2845 switch (op) {
2846 case AssignExpr::MulAssign:
2847 basicop = BinaryExpr::Mul;
2848 break;
2849 case AssignExpr::DivAssign:
2850 basicop = BinaryExpr::Div;
2851 break;
2852 case AssignExpr::ModAssign:
2853 basicop = BinaryExpr::Mod;
2854 break;
2855 case AssignExpr::AddAssign:
2856 basicop = BinaryExpr::Add;
2857 break;
2858 case AssignExpr::SubAssign:
2859 basicop = BinaryExpr::Sub;
2860 break;
2861 case AssignExpr::ShlAssign:
2862 basicop = BinaryExpr::Shl;
2863 break;
2864 case AssignExpr::ShrAssign:
2865 basicop = BinaryExpr::Shr;
2866 break;
2867 case AssignExpr::AndAssign:
2868 basicop = BinaryExpr::BitAnd;
2869 break;
2870 case AssignExpr::XorAssign:
2871 basicop = BinaryExpr::BitXor;
2872 break;
2873 case AssignExpr::OrAssign:
2874 basicop = BinaryExpr::BitOr;
2875 break;
2876 default:
2877 FATAL("logic error in lEmitOpAssign()");
2878 return NULL;
2879 }
2880
2881 // Emit the code to compute the new value
2882 llvm::Value *newValue = NULL;
2883 switch (op) {
2884 case AssignExpr::MulAssign:
2885 case AssignExpr::DivAssign:
2886 case AssignExpr::ModAssign:
2887 case AssignExpr::AddAssign:
2888 case AssignExpr::SubAssign:
2889 newValue = lEmitBinaryArith(basicop, oldLHS, rvalue, type, arg1->GetType(), ctx, pos);
2890 break;
2891 case AssignExpr::ShlAssign:
2892 case AssignExpr::ShrAssign:
2893 case AssignExpr::AndAssign:
2894 case AssignExpr::XorAssign:
2895 case AssignExpr::OrAssign:
2896 newValue = lEmitBinaryBitOp(basicop, oldLHS, rvalue, arg0->GetType()->IsUnsignedType(), ctx);
2897 break;
2898 default:
2899 FATAL("logic error in lEmitOpAssign");
2900 return NULL;
2901 }
2902
2903 // And store the result back to the lvalue.
2904 ctx->SetDebugPos(arg0->pos);
2905 lStoreAssignResult(newValue, lv, resultType, lvalueType, ctx, baseSym);
2906
2907 return newValue;
2908 }
2909
AssignExpr(AssignExpr::Op o,Expr * a,Expr * b,SourcePos p)2910 AssignExpr::AssignExpr(AssignExpr::Op o, Expr *a, Expr *b, SourcePos p) : Expr(p, AssignExprID), op(o) {
2911 lvalue = a;
2912 rvalue = b;
2913 }
2914
GetValue(FunctionEmitContext * ctx) const2915 llvm::Value *AssignExpr::GetValue(FunctionEmitContext *ctx) const {
2916 const Type *type = NULL;
2917 if (lvalue == NULL || rvalue == NULL || (type = GetType()) == NULL)
2918 return NULL;
2919 ctx->SetDebugPos(pos);
2920
2921 Symbol *baseSym = lvalue->GetBaseSymbol();
2922
2923 switch (op) {
2924 case Assign: {
2925 llvm::Value *ptr = lvalue->GetLValue(ctx);
2926 if (ptr == NULL) {
2927 Error(lvalue->pos, "Left hand side of assignment expression can't "
2928 "be assigned to.");
2929 return NULL;
2930 }
2931 const Type *ptrType = lvalue->GetLValueType();
2932 const Type *valueType = rvalue->GetType();
2933 if (ptrType == NULL || valueType == NULL) {
2934 AssertPos(pos, m->errorCount > 0);
2935 return NULL;
2936 }
2937
2938 llvm::Value *value = rvalue->GetValue(ctx);
2939 if (value == NULL) {
2940 AssertPos(pos, m->errorCount > 0);
2941 return NULL;
2942 }
2943
2944 ctx->SetDebugPos(lvalue->pos);
2945
2946 lStoreAssignResult(value, ptr, valueType, ptrType, ctx, baseSym);
2947
2948 return value;
2949 }
2950 case MulAssign:
2951 case DivAssign:
2952 case ModAssign:
2953 case AddAssign:
2954 case SubAssign:
2955 case ShlAssign:
2956 case ShrAssign:
2957 case AndAssign:
2958 case XorAssign:
2959 case OrAssign: {
2960 // This should be caught during type checking
2961 AssertPos(pos, !CastType<ArrayType>(type) && !CastType<StructType>(type));
2962 return lEmitOpAssign(op, lvalue, rvalue, type, baseSym, pos, ctx);
2963 }
2964 default:
2965 FATAL("logic error in AssignExpr::GetValue()");
2966 return NULL;
2967 }
2968 }
2969
Optimize()2970 Expr *AssignExpr::Optimize() {
2971 if (lvalue == NULL || rvalue == NULL)
2972 return NULL;
2973 return this;
2974 }
2975
GetType() const2976 const Type *AssignExpr::GetType() const { return lvalue ? lvalue->GetType() : NULL; }
2977
2978 /** Recursively checks a structure type to see if it (or any struct type
2979 that it holds) has a const-qualified member. */
lCheckForConstStructMember(SourcePos pos,const StructType * structType,const StructType * initialType)2980 static bool lCheckForConstStructMember(SourcePos pos, const StructType *structType, const StructType *initialType) {
2981 for (int i = 0; i < structType->GetElementCount(); ++i) {
2982 const Type *t = structType->GetElementType(i);
2983 if (t->IsConstType()) {
2984 if (structType == initialType)
2985 Error(pos,
2986 "Illegal to assign to type \"%s\" due to element "
2987 "\"%s\" with type \"%s\".",
2988 structType->GetString().c_str(), structType->GetElementName(i).c_str(), t->GetString().c_str());
2989 else
2990 Error(pos,
2991 "Illegal to assign to type \"%s\" in type \"%s\" "
2992 "due to element \"%s\" with type \"%s\".",
2993 structType->GetString().c_str(), initialType->GetString().c_str(),
2994 structType->GetElementName(i).c_str(), t->GetString().c_str());
2995 return true;
2996 }
2997
2998 const StructType *st = CastType<StructType>(t);
2999 if (st != NULL && lCheckForConstStructMember(pos, st, initialType))
3000 return true;
3001 }
3002 return false;
3003 }
3004
TypeCheck()3005 Expr *AssignExpr::TypeCheck() {
3006 if (lvalue == NULL || rvalue == NULL)
3007 return NULL;
3008
3009 bool lvalueIsReference = CastType<ReferenceType>(lvalue->GetType()) != NULL;
3010 if (lvalueIsReference)
3011 lvalue = new RefDerefExpr(lvalue, lvalue->pos);
3012
3013 if (PossiblyResolveFunctionOverloads(rvalue, lvalue->GetType()) == false) {
3014 Error(pos, "Unable to find overloaded function for function "
3015 "pointer assignment.");
3016 return NULL;
3017 }
3018
3019 const Type *lhsType = lvalue->GetType();
3020 if (lhsType == NULL) {
3021 AssertPos(pos, m->errorCount > 0);
3022 return NULL;
3023 }
3024
3025 if (lhsType->IsConstType()) {
3026 Error(lvalue->pos,
3027 "Can't assign to type \"%s\" on left-hand side of "
3028 "expression.",
3029 lhsType->GetString().c_str());
3030 return NULL;
3031 }
3032
3033 if (CastType<PointerType>(lhsType) != NULL) {
3034 if (op == AddAssign || op == SubAssign) {
3035 if (PointerType::IsVoidPointer(lhsType)) {
3036 Error(pos,
3037 "Illegal to perform pointer arithmetic on \"%s\" "
3038 "type.",
3039 lhsType->GetString().c_str());
3040 return NULL;
3041 }
3042
3043 const Type *deltaType = g->target->is32Bit() ? AtomicType::UniformInt32 : AtomicType::UniformInt64;
3044 if (lhsType->IsVaryingType())
3045 deltaType = deltaType->GetAsVaryingType();
3046 rvalue = TypeConvertExpr(rvalue, deltaType, lOpString(op));
3047 } else if (op == Assign)
3048 rvalue = TypeConvertExpr(rvalue, lhsType, "assignment");
3049 else {
3050 Error(lvalue->pos, "Assignment operator \"%s\" is illegal with pointer types.", lOpString(op));
3051 return NULL;
3052 }
3053 } else if (CastType<ArrayType>(lhsType) != NULL) {
3054 Error(lvalue->pos, "Illegal to assign to array type \"%s\".", lhsType->GetString().c_str());
3055 return NULL;
3056 } else
3057 rvalue = TypeConvertExpr(rvalue, lhsType, lOpString(op));
3058
3059 if (rvalue == NULL)
3060 return NULL;
3061
3062 if (lhsType->IsFloatType() == true &&
3063 (op == ShlAssign || op == ShrAssign || op == AndAssign || op == XorAssign || op == OrAssign)) {
3064 Error(pos,
3065 "Illegal to use %s operator with floating-point "
3066 "operands.",
3067 lOpString(op));
3068 return NULL;
3069 }
3070
3071 const StructType *st = CastType<StructType>(lhsType);
3072 if (st != NULL) {
3073 // Make sure we're not assigning to a struct that has a constant member
3074 if (lCheckForConstStructMember(pos, st, st))
3075 return NULL;
3076
3077 if (op != Assign) {
3078 Error(lvalue->pos,
3079 "Assignment operator \"%s\" is illegal with struct "
3080 "type \"%s\".",
3081 lOpString(op), st->GetString().c_str());
3082 return NULL;
3083 }
3084 }
3085 return this;
3086 }
3087
EstimateCost() const3088 int AssignExpr::EstimateCost() const {
3089 if (op == Assign)
3090 return COST_ASSIGN;
3091 if (op == DivAssign || op == ModAssign)
3092 return COST_ASSIGN + COST_COMPLEX_ARITH_OP;
3093 else
3094 return COST_ASSIGN + COST_SIMPLE_ARITH_LOGIC_OP;
3095 }
3096
Print() const3097 void AssignExpr::Print() const {
3098 if (!lvalue || !rvalue || !GetType())
3099 return;
3100
3101 printf("[%s] assign (", GetType()->GetString().c_str());
3102 lvalue->Print();
3103 printf(" %s ", lOpString(op));
3104 rvalue->Print();
3105 printf(")");
3106 pos.Print();
3107 }
3108
3109 ///////////////////////////////////////////////////////////////////////////
3110 // SelectExpr
3111
SelectExpr(Expr * t,Expr * e1,Expr * e2,SourcePos p)3112 SelectExpr::SelectExpr(Expr *t, Expr *e1, Expr *e2, SourcePos p) : Expr(p, SelectExprID) {
3113 test = t;
3114 expr1 = e1;
3115 expr2 = e2;
3116 }
3117
3118 /** Emit code to select between two varying values based on a varying test
3119 value.
3120 */
lEmitVaryingSelect(FunctionEmitContext * ctx,llvm::Value * test,llvm::Value * expr1,llvm::Value * expr2,const Type * type)3121 static llvm::Value *lEmitVaryingSelect(FunctionEmitContext *ctx, llvm::Value *test, llvm::Value *expr1,
3122 llvm::Value *expr2, const Type *type) {
3123
3124 llvm::Value *resultPtr = ctx->AllocaInst(type, "selectexpr_tmp");
3125 Assert(resultPtr != NULL);
3126 // Don't need to worry about masking here
3127 ctx->StoreInst(expr2, resultPtr, type, type->IsUniformType());
3128 // Use masking to conditionally store the expr1 values
3129 Assert(resultPtr->getType() == PointerType::GetUniform(type)->LLVMStorageType(g->ctx));
3130 ctx->StoreInst(expr1, resultPtr, test, type, PointerType::GetUniform(type));
3131 return ctx->LoadInst(resultPtr, type, "selectexpr_final");
3132 }
3133
lEmitSelectExprCode(FunctionEmitContext * ctx,llvm::Value * testVal,llvm::Value * oldMask,llvm::Value * fullMask,Expr * expr,llvm::Value * exprPtr)3134 static void lEmitSelectExprCode(FunctionEmitContext *ctx, llvm::Value *testVal, llvm::Value *oldMask,
3135 llvm::Value *fullMask, Expr *expr, llvm::Value *exprPtr) {
3136 llvm::BasicBlock *bbEval = ctx->CreateBasicBlock("select_eval_expr", ctx->GetCurrentBasicBlock());
3137 llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("select_done", bbEval);
3138
3139 // Check to see if the test was true for any of the currently executing
3140 // program instances.
3141 llvm::Value *testAndFullMask = ctx->BinaryOperator(llvm::Instruction::And, testVal, fullMask, "test&mask");
3142 llvm::Value *anyOn = ctx->Any(testAndFullMask);
3143 ctx->BranchInst(bbEval, bbDone, anyOn);
3144
3145 ctx->SetCurrentBasicBlock(bbEval);
3146 llvm::Value *testAndMask = ctx->BinaryOperator(llvm::Instruction::And, testVal, oldMask, "test&mask");
3147 ctx->SetInternalMask(testAndMask);
3148 llvm::Value *exprVal = expr->GetValue(ctx);
3149 ctx->StoreInst(exprVal, exprPtr, expr->GetType(), expr->GetType()->IsUniformType());
3150 ctx->BranchInst(bbDone);
3151
3152 ctx->SetCurrentBasicBlock(bbDone);
3153 }
3154
HasAmbiguousVariability(std::vector<const Expr * > & warn) const3155 bool SelectExpr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
3156 bool isExpr1Amb = false;
3157 bool isExpr2Amb = false;
3158 if (expr1 != NULL) {
3159 const Type *type1 = expr1->GetType();
3160 if (expr1->HasAmbiguousVariability(warn)) {
3161 isExpr1Amb = true;
3162 } else if ((type1 != NULL) && (type1->IsVaryingType())) {
3163 // If either expr is varying, then the expression is un-ambiguously varying.
3164 return false;
3165 }
3166 }
3167 if (expr2 != NULL) {
3168 const Type *type2 = expr2->GetType();
3169 if (expr2->HasAmbiguousVariability(warn)) {
3170 isExpr2Amb = true;
3171 } else if ((type2 != NULL) && (type2->IsVaryingType())) {
3172 // If either arg is varying, then the expression is un-ambiguously varying.
3173 return false;
3174 }
3175 }
3176 if (isExpr1Amb || isExpr2Amb) {
3177 return true;
3178 }
3179
3180 return false;
3181 }
3182
GetValue(FunctionEmitContext * ctx) const3183 llvm::Value *SelectExpr::GetValue(FunctionEmitContext *ctx) const {
3184 if (!expr1 || !expr2 || !test)
3185 return NULL;
3186
3187 ctx->SetDebugPos(pos);
3188
3189 const Type *testType = test->GetType()->GetAsNonConstType();
3190 // This should be taken care of during typechecking
3191 AssertPos(pos, Type::Equal(testType->GetBaseType(), AtomicType::UniformBool) ||
3192 Type::Equal(testType->GetBaseType(), AtomicType::VaryingBool));
3193
3194 const Type *type = expr1->GetType();
3195
3196 if (Type::Equal(testType, AtomicType::UniformBool)) {
3197 // Simple case of a single uniform bool test expression; we just
3198 // want one of the two expressions. In this case, we can be
3199 // careful to evaluate just the one of the expressions that we need
3200 // the value of so that if the other one has side-effects or
3201 // accesses invalid memory, it doesn't execute.
3202 llvm::Value *testVal = test->GetValue(ctx);
3203 llvm::BasicBlock *testTrue = ctx->CreateBasicBlock("select_true", ctx->GetCurrentBasicBlock());
3204 llvm::BasicBlock *testFalse = ctx->CreateBasicBlock("select_false", testTrue);
3205 llvm::BasicBlock *testDone = ctx->CreateBasicBlock("select_done", testFalse);
3206 ctx->BranchInst(testTrue, testFalse, testVal);
3207
3208 ctx->SetCurrentBasicBlock(testTrue);
3209 llvm::Value *expr1Val = expr1->GetValue(ctx);
3210 // Note that truePred won't be necessarily equal to testTrue, in
3211 // case the expr1->GetValue() call changes the current basic block.
3212 llvm::BasicBlock *truePred = ctx->GetCurrentBasicBlock();
3213 ctx->BranchInst(testDone);
3214
3215 ctx->SetCurrentBasicBlock(testFalse);
3216 llvm::Value *expr2Val = expr2->GetValue(ctx);
3217 // See comment above truePred for why we can't just assume we're in
3218 // the testFalse basic block here.
3219 llvm::BasicBlock *falsePred = ctx->GetCurrentBasicBlock();
3220 ctx->BranchInst(testDone);
3221
3222 ctx->SetCurrentBasicBlock(testDone);
3223 llvm::PHINode *ret = ctx->PhiNode(expr1Val->getType(), 2, "select");
3224 ret->addIncoming(expr1Val, truePred);
3225 ret->addIncoming(expr2Val, falsePred);
3226 return ret;
3227 } else if (CastType<VectorType>(testType) == NULL) {
3228 // the test is a varying bool type
3229 llvm::Value *testVal = test->GetValue(ctx);
3230 AssertPos(pos, testVal->getType() == LLVMTypes::MaskType);
3231 llvm::Value *oldMask = ctx->GetInternalMask();
3232 llvm::Value *fullMask = ctx->GetFullMask();
3233
3234 // We don't want to incur the overhead for short-circuit evaluation
3235 // for expressions that are both computationally simple and safe to
3236 // run with an "all off" mask.
3237 int threshold = g->target->isGenXTarget() ? PREDICATE_SAFE_SHORT_CIRC_GENX_STATEMENT_COST
3238 : PREDICATE_SAFE_IF_STATEMENT_COST;
3239 bool shortCircuit1 = (::EstimateCost(expr1) > threshold || SafeToRunWithMaskAllOff(expr1) == false);
3240 bool shortCircuit2 = (::EstimateCost(expr2) > threshold || SafeToRunWithMaskAllOff(expr2) == false);
3241
3242 Debug(expr1->pos, "%sshort circuiting evaluation for select expr", shortCircuit1 ? "" : "Not ");
3243 Debug(expr2->pos, "%sshort circuiting evaluation for select expr", shortCircuit2 ? "" : "Not ");
3244
3245 // Temporary storage to store the values computed for each
3246 // expression, if any. (These stay as uninitialized memory if we
3247 // short circuit around the corresponding expression.)
3248 llvm::Value *expr1Ptr = ctx->AllocaInst(expr1->GetType());
3249 llvm::Value *expr2Ptr = ctx->AllocaInst(expr1->GetType());
3250
3251 if (shortCircuit1)
3252 lEmitSelectExprCode(ctx, testVal, oldMask, fullMask, expr1, expr1Ptr);
3253 else {
3254 ctx->SetInternalMaskAnd(oldMask, testVal);
3255 llvm::Value *expr1Val = expr1->GetValue(ctx);
3256 ctx->StoreInst(expr1Val, expr1Ptr, expr1->GetType(), expr1->GetType()->IsUniformType());
3257 }
3258
3259 if (shortCircuit2) {
3260 llvm::Value *notTest = ctx->NotOperator(testVal);
3261 lEmitSelectExprCode(ctx, notTest, oldMask, fullMask, expr2, expr2Ptr);
3262 } else {
3263 ctx->SetInternalMaskAndNot(oldMask, testVal);
3264 llvm::Value *expr2Val = expr2->GetValue(ctx);
3265 ctx->StoreInst(expr2Val, expr2Ptr, expr2->GetType(), expr2->GetType()->IsUniformType());
3266 }
3267
3268 ctx->SetInternalMask(oldMask);
3269 llvm::Value *expr1Val = ctx->LoadInst(expr1Ptr, expr1->GetType());
3270 llvm::Value *expr2Val = ctx->LoadInst(expr2Ptr, expr2->GetType());
3271 return lEmitVaryingSelect(ctx, testVal, expr1Val, expr2Val, type);
3272 } else {
3273 // FIXME? Short-circuiting doesn't work in the case of
3274 // vector-valued test expressions. (We could also just prohibit
3275 // these and place the issue in the user's hands...)
3276 llvm::Value *testVal = test->GetValue(ctx);
3277 llvm::Value *expr1Val = expr1->GetValue(ctx);
3278 llvm::Value *expr2Val = expr2->GetValue(ctx);
3279
3280 ctx->SetDebugPos(pos);
3281 const VectorType *vt = CastType<VectorType>(type);
3282 // Things that typechecking should have caught
3283 AssertPos(pos, vt != NULL);
3284 AssertPos(pos, CastType<VectorType>(testType) != NULL &&
3285 (CastType<VectorType>(testType)->GetElementCount() == vt->GetElementCount()));
3286
3287 // Do an element-wise select
3288 llvm::Value *result = llvm::UndefValue::get(type->LLVMType(g->ctx));
3289 for (int i = 0; i < vt->GetElementCount(); ++i) {
3290 llvm::Value *ti = ctx->ExtractInst(testVal, i);
3291 llvm::Value *e1i = ctx->ExtractInst(expr1Val, i);
3292 llvm::Value *e2i = ctx->ExtractInst(expr2Val, i);
3293 llvm::Value *sel = NULL;
3294 if (testType->IsUniformType()) {
3295 // Extracting uniform vector bool to uniform bool require
3296 // switching from i8 -> i1
3297 ti = ctx->SwitchBoolSize(ti, LLVMTypes::BoolType);
3298 sel = ctx->SelectInst(ti, e1i, e2i);
3299 } else {
3300 // Extracting varying vector bools to varying bools require
3301 // switching from <WIDTH x i8> -> <WIDTH x MaskType>
3302 ti = ctx->SwitchBoolSize(ti, LLVMTypes::BoolVectorType);
3303 sel = lEmitVaryingSelect(ctx, ti, e1i, e2i, vt->GetElementType());
3304 }
3305 result = ctx->InsertInst(result, sel, i);
3306 }
3307 return result;
3308 }
3309 }
3310
GetType() const3311 const Type *SelectExpr::GetType() const {
3312 if (!test || !expr1 || !expr2)
3313 return NULL;
3314
3315 const Type *testType = test->GetType();
3316 const Type *expr1Type = expr1->GetType();
3317 const Type *expr2Type = expr2->GetType();
3318
3319 if (!testType || !expr1Type || !expr2Type)
3320 return NULL;
3321
3322 bool becomesVarying = (testType->IsVaryingType() || expr1Type->IsVaryingType() || expr2Type->IsVaryingType());
3323 // if expr1 and expr2 have different vector sizes, typechecking should fail...
3324 int testVecSize = CastType<VectorType>(testType) != NULL ? CastType<VectorType>(testType)->GetElementCount() : 0;
3325 int expr1VecSize = CastType<VectorType>(expr1Type) != NULL ? CastType<VectorType>(expr1Type)->GetElementCount() : 0;
3326 AssertPos(pos, !(testVecSize != 0 && expr1VecSize != 0 && testVecSize != expr1VecSize));
3327
3328 int vectorSize = std::max(testVecSize, expr1VecSize);
3329 return Type::MoreGeneralType(expr1Type, expr2Type, Union(expr1->pos, expr2->pos), "select expression",
3330 becomesVarying, vectorSize);
3331 }
3332
3333 template <typename T>
lConstFoldSelect(const bool bv[],ConstExpr * constExpr1,ConstExpr * constExpr2,const Type * exprType,SourcePos pos)3334 Expr *lConstFoldSelect(const bool bv[], ConstExpr *constExpr1, ConstExpr *constExpr2, const Type *exprType,
3335 SourcePos pos) {
3336 T v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC];
3337 T result[ISPC_MAX_NVEC];
3338 int count = constExpr1->GetValues(v1);
3339 constExpr2->GetValues(v2);
3340 for (int i = 0; i < count; ++i)
3341 result[i] = bv[i] ? v1[i] : v2[i];
3342 return new ConstExpr(exprType, result, pos);
3343 }
3344
Optimize()3345 Expr *SelectExpr::Optimize() {
3346 if (test == NULL || expr1 == NULL || expr2 == NULL)
3347 return NULL;
3348
3349 ConstExpr *constTest = llvm::dyn_cast<ConstExpr>(test);
3350 if (constTest == NULL)
3351 return this;
3352
3353 // The test is a constant; see if we can resolve to one of the
3354 // expressions..
3355 bool bv[ISPC_MAX_NVEC];
3356 int count = constTest->GetValues(bv);
3357 if (count == 1)
3358 // Uniform test value; return the corresponding expression
3359 return (bv[0] == true) ? expr1 : expr2;
3360 else {
3361 // Varying test: see if all of the values are the same; if so, then
3362 // return the corresponding expression
3363 bool first = bv[0];
3364 bool mismatch = false;
3365 for (int i = 0; i < count; ++i)
3366 if (bv[i] != first) {
3367 mismatch = true;
3368 break;
3369 }
3370 if (mismatch == false)
3371 return (bv[0] == true) ? expr1 : expr2;
3372
3373 // Last chance: see if the two expressions are constants; if so,
3374 // then we can do an element-wise selection based on the constant
3375 // condition..
3376 ConstExpr *constExpr1 = llvm::dyn_cast<ConstExpr>(expr1);
3377 ConstExpr *constExpr2 = llvm::dyn_cast<ConstExpr>(expr2);
3378 if (constExpr1 == NULL || constExpr2 == NULL)
3379 return this;
3380
3381 AssertPos(pos, Type::Equal(constExpr1->GetType(), constExpr2->GetType()));
3382 const Type *exprType = constExpr1->GetType()->GetAsNonConstType();
3383 AssertPos(pos, exprType->IsVaryingType());
3384
3385 if (Type::Equal(exprType, AtomicType::VaryingInt8)) {
3386 return lConstFoldSelect<int8_t>(bv, constExpr1, constExpr2, exprType, pos);
3387 } else if (Type::Equal(exprType, AtomicType::VaryingUInt8)) {
3388 return lConstFoldSelect<uint8_t>(bv, constExpr1, constExpr2, exprType, pos);
3389 } else if (Type::Equal(exprType, AtomicType::VaryingInt16)) {
3390 return lConstFoldSelect<int16_t>(bv, constExpr1, constExpr2, exprType, pos);
3391 } else if (Type::Equal(exprType, AtomicType::VaryingUInt16)) {
3392 return lConstFoldSelect<uint16_t>(bv, constExpr1, constExpr2, exprType, pos);
3393 } else if (Type::Equal(exprType, AtomicType::VaryingInt32)) {
3394 return lConstFoldSelect<int32_t>(bv, constExpr1, constExpr2, exprType, pos);
3395 } else if (Type::Equal(exprType, AtomicType::VaryingUInt32)) {
3396 return lConstFoldSelect<uint32_t>(bv, constExpr1, constExpr2, exprType, pos);
3397 } else if (Type::Equal(exprType, AtomicType::VaryingInt64)) {
3398 return lConstFoldSelect<int64_t>(bv, constExpr1, constExpr2, exprType, pos);
3399 } else if (Type::Equal(exprType, AtomicType::VaryingUInt64)) {
3400 return lConstFoldSelect<uint64_t>(bv, constExpr1, constExpr2, exprType, pos);
3401 } else if (Type::Equal(exprType, AtomicType::VaryingFloat)) {
3402 return lConstFoldSelect<float>(bv, constExpr1, constExpr2, exprType, pos);
3403 } else if (Type::Equal(exprType, AtomicType::VaryingDouble)) {
3404 return lConstFoldSelect<bool>(bv, constExpr1, constExpr2, exprType, pos);
3405 } else if (Type::Equal(exprType, AtomicType::VaryingBool)) {
3406 return lConstFoldSelect<double>(bv, constExpr1, constExpr2, exprType, pos);
3407 }
3408
3409 return this;
3410 }
3411 }
3412
TypeCheck()3413 Expr *SelectExpr::TypeCheck() {
3414 if (test == NULL || expr1 == NULL || expr2 == NULL)
3415 return NULL;
3416
3417 const Type *type1 = expr1->GetType(), *type2 = expr2->GetType();
3418 if (!type1 || !type2)
3419 return NULL;
3420
3421 if (const ArrayType *at1 = CastType<ArrayType>(type1)) {
3422 expr1 = TypeConvertExpr(expr1, PointerType::GetUniform(at1->GetBaseType()), "select");
3423 if (expr1 == NULL)
3424 return NULL;
3425 type1 = expr1->GetType();
3426 }
3427 if (const ArrayType *at2 = CastType<ArrayType>(type2)) {
3428 expr2 = TypeConvertExpr(expr2, PointerType::GetUniform(at2->GetBaseType()), "select");
3429 if (expr2 == NULL)
3430 return NULL;
3431 type2 = expr2->GetType();
3432 }
3433
3434 const Type *testType = test->GetType();
3435 if (testType == NULL)
3436 return NULL;
3437 test = TypeConvertExpr(test, lMatchingBoolType(testType), "select");
3438 if (test == NULL)
3439 return NULL;
3440 testType = test->GetType();
3441
3442 int testVecSize = CastType<VectorType>(testType) ? CastType<VectorType>(testType)->GetElementCount() : 0;
3443 const Type *promotedType = Type::MoreGeneralType(type1, type2, Union(expr1->pos, expr2->pos), "select expression",
3444 testType->IsVaryingType(), testVecSize);
3445
3446 // If the promoted type is a ReferenceType, the expression type will be
3447 // the reference target type since SelectExpr is always a rvalue.
3448 if (CastType<ReferenceType>(promotedType) != NULL)
3449 promotedType = promotedType->GetReferenceTarget();
3450
3451 if (promotedType == NULL)
3452 return NULL;
3453
3454 expr1 = TypeConvertExpr(expr1, promotedType, "select");
3455 expr2 = TypeConvertExpr(expr2, promotedType, "select");
3456 if (expr1 == NULL || expr2 == NULL)
3457 return NULL;
3458
3459 return this;
3460 }
3461
EstimateCost() const3462 int SelectExpr::EstimateCost() const { return COST_SELECT; }
3463
Print() const3464 void SelectExpr::Print() const {
3465 if (!test || !expr1 || !expr2 || !GetType())
3466 return;
3467
3468 printf("[%s] (", GetType()->GetString().c_str());
3469 test->Print();
3470 printf(" ? ");
3471 expr1->Print();
3472 printf(" : ");
3473 expr2->Print();
3474 printf(")");
3475 pos.Print();
3476 }
3477
3478 ///////////////////////////////////////////////////////////////////////////
3479 // FunctionCallExpr
3480
FunctionCallExpr(Expr * f,ExprList * a,SourcePos p,bool il,Expr * lce[3])3481 FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, bool il, Expr *lce[3])
3482 : Expr(p, FunctionCallExprID), isLaunch(il) {
3483 func = f;
3484 args = a;
3485 std::vector<const Expr *> warn;
3486 if (a->HasAmbiguousVariability(warn) == true) {
3487 for (auto w : warn) {
3488 const TypeCastExpr *tExpr = llvm::dyn_cast<TypeCastExpr>(w);
3489 tExpr->PrintAmbiguousVariability();
3490 }
3491 }
3492 if (lce != NULL) {
3493 launchCountExpr[0] = lce[0];
3494 launchCountExpr[1] = lce[1];
3495 launchCountExpr[2] = lce[2];
3496 } else
3497 launchCountExpr[0] = launchCountExpr[1] = launchCountExpr[2] = NULL;
3498 }
3499
lGetFunctionType(Expr * func)3500 static const FunctionType *lGetFunctionType(Expr *func) {
3501 if (func == NULL)
3502 return NULL;
3503
3504 const Type *type = func->GetType();
3505 if (type == NULL)
3506 return NULL;
3507
3508 const FunctionType *ftype = CastType<FunctionType>(type);
3509 if (ftype == NULL) {
3510 // Not a regular function symbol--is it a function pointer?
3511 if (CastType<PointerType>(type) != NULL)
3512 ftype = CastType<FunctionType>(type->GetBaseType());
3513 }
3514 return ftype;
3515 }
3516
/** Emits the code for this function call or task launch.

    The callee is evaluated first; then each argument is type-converted to
    the corresponding parameter type, missing trailing arguments are filled
    in from the function's parameter defaults, all arguments are evaluated,
    and finally either a launch (for task functions) or a call is emitted.

    @param ctx  emit context that receives the instructions
    @return the value returned by the call, or NULL when the function
            returns void, when a task is launched (no return value is
            recorded in that case), or when an error was encountered. */
llvm::Value *FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const {
    if (func == NULL || args == NULL)
        return NULL;

    ctx->SetDebugPos(pos);

    llvm::Value *callee = func->GetValue(ctx);

    if (callee == NULL) {
        // A missing callee value is only expected after an earlier error.
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    const FunctionType *ft = lGetFunctionType(func);
    AssertPos(pos, ft != NULL);
    bool isVoidFunc = ft->GetReturnType()->IsVoidType();

    // Automatically convert function call args to references if needed.
    // FIXME: this should move to the TypeCheck() method... (but the
    // GetLValue call below needs a FunctionEmitContext, which is
    // problematic...)
    // Work on a copy so that the conversions do not modify args->exprs.
    std::vector<Expr *> callargs = args->exprs;

    // Specifically, this can happen if there's an error earlier during
    // overload resolution.
    if ((int)callargs.size() > ft->GetNumParameters()) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    for (unsigned int i = 0; i < callargs.size(); ++i) {
        Expr *argExpr = callargs[i];
        // NULL arguments are left in place here; they make the evaluation
        // loop further down bail out.
        if (argExpr == NULL)
            continue;

        const Type *paramType = ft->GetParameterType(i);

        // An argument whose lvalue type is a varying pointer cannot be
        // bound to a reference parameter.
        const Type *argLValueType = argExpr->GetLValueType();
        if (argLValueType != NULL && CastType<PointerType>(argLValueType) != NULL && argLValueType->IsVaryingType() &&
            CastType<ReferenceType>(paramType) != NULL) {
            Error(argExpr->pos,
                  "Illegal to pass a \"varying\" lvalue to a "
                  "reference parameter of type \"%s\".",
                  paramType->GetString().c_str());
            return NULL;
        }

        // Do whatever type conversion is needed
        argExpr = TypeConvertExpr(argExpr, paramType, "function call argument");
        if (argExpr == NULL)
            return NULL;
        callargs[i] = argExpr;
    }

    // Fill in any default argument values needed.
    // FIXME: should we do this during type checking?
    for (int i = callargs.size(); i < ft->GetNumParameters(); ++i) {
        Expr *paramDefault = ft->GetParameterDefault(i);
        const Type *paramType = ft->GetParameterType(i);
        // FIXME: this type conv should happen when we create the function
        // type!
        Expr *d = TypeConvertExpr(paramDefault, paramType, "function call default argument");
        if (d == NULL)
            return NULL;
        callargs.push_back(d);
    }

    // Now evaluate the values of all of the parameters being passed.
    std::vector<llvm::Value *> argVals;
    for (unsigned int i = 0; i < callargs.size(); ++i) {
        Expr *argExpr = callargs[i];
        if (argExpr == NULL)
            // give up; we hit an error earlier
            return NULL;

        llvm::Value *argValue = argExpr->GetValue(ctx);
        if (argValue == NULL)
            // something went wrong in evaluating the argument's
            // expression, so give up on this
            return NULL;

        argVals.push_back(argValue);
    }

    llvm::Value *retVal = NULL;
    // Reset the debug position, which was set before the callee and
    // arguments were evaluated above.
    ctx->SetDebugPos(pos);
    if (ft->isTask) {
        // Task functions are launched rather than called; retVal stays NULL.
        AssertPos(pos, launchCountExpr[0] != NULL);
        llvm::Value *launchCount[3] = {launchCountExpr[0]->GetValue(ctx), launchCountExpr[1]->GetValue(ctx),
                                       launchCountExpr[2]->GetValue(ctx)};

        // Only the first launch count is checked before emitting the launch.
        if (launchCount[0] != NULL)
            ctx->LaunchInst(callee, argVals, launchCount);
    } else
        retVal = ctx->CallInst(callee, ft, argVals, isVoidFunc ? "" : "calltmp");

    if (isVoidFunc)
        return NULL;
    else
        return retVal;
}
3618
GetLValue(FunctionEmitContext * ctx) const3619 llvm::Value *FunctionCallExpr::GetLValue(FunctionEmitContext *ctx) const {
3620 if (GetLValueType() != NULL) {
3621 return GetValue(ctx);
3622 } else {
3623 // Only be a valid LValue type if the function
3624 // returns a pointer or reference.
3625 return NULL;
3626 }
3627 }
3628
FullResolveOverloads(Expr * func,ExprList * args,std::vector<const Type * > * argTypes,std::vector<bool> * argCouldBeNULL,std::vector<bool> * argIsConstant)3629 bool FullResolveOverloads(Expr *func, ExprList *args, std::vector<const Type *> *argTypes,
3630 std::vector<bool> *argCouldBeNULL, std::vector<bool> *argIsConstant) {
3631 for (unsigned int i = 0; i < args->exprs.size(); ++i) {
3632 Expr *expr = args->exprs[i];
3633 if (expr == NULL)
3634 return false;
3635 const Type *t = expr->GetType();
3636 if (t == NULL)
3637 return false;
3638 argTypes->push_back(t);
3639 argCouldBeNULL->push_back(lIsAllIntZeros(expr) || llvm::dyn_cast<NullPointerExpr>(expr));
3640 argIsConstant->push_back(llvm::dyn_cast<ConstExpr>(expr) || llvm::dyn_cast<NullPointerExpr>(expr));
3641 }
3642 return true;
3643 }
3644
GetType() const3645 const Type *FunctionCallExpr::GetType() const {
3646 std::vector<const Type *> argTypes;
3647 std::vector<bool> argCouldBeNULL, argIsConstant;
3648 if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == true) {
3649 FunctionSymbolExpr *fse = llvm::dyn_cast<FunctionSymbolExpr>(func);
3650 if (fse != NULL) {
3651 fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, &argIsConstant);
3652 }
3653 }
3654 const FunctionType *ftype = lGetFunctionType(func);
3655 return ftype ? ftype->GetReturnType() : NULL;
3656 }
3657
GetLValueType() const3658 const Type *FunctionCallExpr::GetLValueType() const {
3659 const FunctionType *ftype = lGetFunctionType(func);
3660 if (ftype && (ftype->GetReturnType()->IsPointerType() || ftype->GetReturnType()->IsReferenceType())) {
3661 return ftype->GetReturnType();
3662 } else {
3663 // Only be a valid LValue type if the function
3664 // returns a pointer or reference.
3665 return NULL;
3666 }
3667 }
3668
Optimize()3669 Expr *FunctionCallExpr::Optimize() {
3670 if (func == NULL || args == NULL)
3671 return NULL;
3672 return this;
3673 }
3674
/** Type-checks the function call.

    Two kinds of callee are handled:
    - a FunctionSymbolExpr (regular call): overloads are resolved against
      the argument types, the callee is type-checked, and the use of
      "launch" is validated against the function's "task" qualifier;
    - anything else (call through a function pointer): the callee must
      have pointer-to-function type, the argument count must fit the
      parameter list (allowing for defaults), and each argument must be
      convertible to its parameter type.

    @return `this` on success; NULL if an operand is missing or any check
            fails (most failure paths issue an error message first). */
Expr *FunctionCallExpr::TypeCheck() {
    if (func == NULL || args == NULL)
        return NULL;

    std::vector<const Type *> argTypes;
    std::vector<bool> argCouldBeNULL, argIsConstant;

    // Gather per-argument type information; fails if any argument is
    // NULL or has no type.
    if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == false) {
        return NULL;
    }

    FunctionSymbolExpr *fse = llvm::dyn_cast<FunctionSymbolExpr>(func);
    if (fse != NULL) {
        // Regular function call
        if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, &argIsConstant) == false)
            return NULL;

        func = ::TypeCheck(fse);
        if (func == NULL)
            return NULL;

        // The type-checked callee must be a function or pointer to function.
        const FunctionType *ft = CastType<FunctionType>(func->GetType());
        if (ft == NULL) {
            const PointerType *pt = CastType<PointerType>(func->GetType());
            ft = (pt == NULL) ? NULL : CastType<FunctionType>(pt->GetBaseType());
        }

        if (ft == NULL) {
            Error(pos, "Valid function name must be used for function call.");
            return NULL;
        }

        if (ft->isTask) {
            // Note that this error does not return by itself; the launch
            // count checks below still run.
            if (!isLaunch)
                Error(pos, "\"launch\" expression needed to call function "
                           "with \"task\" qualifier.");
            // All three launch counts must be present and convertible to
            // uniform int32.
            for (int k = 0; k < 3; k++) {
                if (!launchCountExpr[k])
                    return NULL;

                launchCountExpr[k] = TypeConvertExpr(launchCountExpr[k], AtomicType::UniformInt32, "task launch count");
                if (launchCountExpr[k] == NULL)
                    return NULL;
            }
        } else {
            if (isLaunch) {
                Error(pos, "\"launch\" expression illegal with non-\"task\"-"
                           "qualified function.");
                return NULL;
            }
            AssertPos(pos, launchCountExpr[0] == NULL);
        }
    } else {
        // Call through a function pointer
        const Type *fptrType = func->GetType();
        if (fptrType == NULL)
            return NULL;

        // Make sure we do in fact have a function to call
        const FunctionType *funcType;
        if (CastType<PointerType>(fptrType) == NULL ||
            (funcType = CastType<FunctionType>(fptrType->GetBaseType())) == NULL) {
            Error(func->pos, "Must provide function name or function pointer for "
                             "function call expression.");
            return NULL;
        }

        // Make sure we don't have too many arguments for the function
        if ((int)argTypes.size() > funcType->GetNumParameters()) {
            Error(args->pos,
                  "Too many parameter values provided in "
                  "function call (%d provided, %d expected).",
                  (int)argTypes.size(), funcType->GetNumParameters());
            return NULL;
        }
        // It's ok to have too few arguments, as long as the function's
        // default parameter values have started by the time we run out
        // of arguments
        if ((int)argTypes.size() < funcType->GetNumParameters() &&
            funcType->GetParameterDefault(argTypes.size()) == NULL) {
            Error(args->pos,
                  "Too few parameter values provided in "
                  "function call (%d provided, %d expected).",
                  (int)argTypes.size(), funcType->GetNumParameters());
            return NULL;
        }

        // Now make sure they can all type convert to the corresponding
        // parameter types..
        for (int i = 0; i < (int)argTypes.size(); ++i) {
            if (i < funcType->GetNumParameters()) {
                // make sure it can type convert
                // (an argument that could be NULL is accepted for any
                // pointer parameter even if the types do not convert)
                const Type *paramType = funcType->GetParameterType(i);
                if (CanConvertTypes(argTypes[i], paramType) == false &&
                    !(argCouldBeNULL[i] == true && CastType<PointerType>(paramType) != NULL)) {
                    Error(args->exprs[i]->pos,
                          "Can't convert argument of "
                          "type \"%s\" to type \"%s\" for function call "
                          "argument.",
                          argTypes[i]->GetString().c_str(), paramType->GetString().c_str());
                    return NULL;
                }
            } else
                // Otherwise the parameter default saves us.  It should
                // be there for sure, given the check right above the
                // for loop.
                AssertPos(pos, funcType->GetParameterDefault(i) != NULL);
        }

        // A varying function pointer may only point to functions that
        // return void or a non-uniform type.
        if (fptrType->IsVaryingType()) {
            const Type *retType = funcType->GetReturnType();
            if (retType->IsVoidType() == false && retType->IsUniformType()) {
                Error(pos,
                      "Illegal to call a varying function pointer that "
                      "points to a function with a uniform return type \"%s\".",
                      funcType->GetReturnType()->GetString().c_str());
                return NULL;
            }
        }
    }

    if (func == NULL || args == NULL)
        return NULL;
    return this;
}
3800
EstimateCost() const3801 int FunctionCallExpr::EstimateCost() const {
3802 if (isLaunch)
3803 return COST_TASK_LAUNCH;
3804
3805 const Type *type = func->GetType();
3806 if (type == NULL)
3807 return 0;
3808
3809 const PointerType *pt = CastType<PointerType>(type);
3810 if (pt != NULL)
3811 type = type->GetBaseType();
3812
3813 const FunctionType *ftype = CastType<FunctionType>(type);
3814 if (ftype != NULL && ftype->costOverride > -1)
3815 return ftype->costOverride;
3816
3817 if (pt != NULL)
3818 return pt->IsUniformType() ? COST_FUNPTR_UNIFORM : COST_FUNPTR_VARYING;
3819 else
3820 return COST_FUNCALL;
3821 }
3822
Print() const3823 void FunctionCallExpr::Print() const {
3824 if (!func || !args || !GetType())
3825 return;
3826
3827 printf("[%s] funcall %s ", GetType()->GetString().c_str(), isLaunch ? "launch" : "");
3828 func->Print();
3829 printf(" args (");
3830 args->Print();
3831 printf(")");
3832 pos.Print();
3833 }
3834
3835 ///////////////////////////////////////////////////////////////////////////
3836 // ExprList
3837
HasAmbiguousVariability(std::vector<const Expr * > & warn) const3838 bool ExprList::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
3839 bool hasAmbiguousVariability = false;
3840 for (unsigned int i = 0; i < exprs.size(); ++i) {
3841 if (exprs[i] != NULL) {
3842 hasAmbiguousVariability |= exprs[i]->HasAmbiguousVariability(warn);
3843 }
3844 }
3845 return hasAmbiguousVariability;
3846 }
3847
GetValue(FunctionEmitContext * ctx) const3848 llvm::Value *ExprList::GetValue(FunctionEmitContext *ctx) const {
3849 FATAL("ExprList::GetValue() should never be called");
3850 return NULL;
3851 }
3852
GetType() const3853 const Type *ExprList::GetType() const {
3854 FATAL("ExprList::GetType() should never be called");
3855 return NULL;
3856 }
3857
Optimize()3858 ExprList *ExprList::Optimize() { return this; }
3859
TypeCheck()3860 ExprList *ExprList::TypeCheck() { return this; }
3861
/** Builds an LLVM constant of the given type from an initializer list.

    @param type           the type the list is initializing
    @param eList          the initializer list
    @param isStorageType  if true, element constants are obtained with
                          GetStorageConstant(); otherwise with GetConstant()
    @return a pair of the resulting constant and a flag that is the logical
            OR of the flags returned for the individual elements (tracked
            here as isNotValidForMultiTargetGlobal).  The constant is NULL
            (and the flag false) if the list cannot initialize the type;
            an error is issued when the element count does not fit. */
static std::pair<llvm::Constant *, bool> lGetExprListConstant(const Type *type, const ExprList *eList,
                                                              bool isStorageType) {
    std::vector<Expr *> exprs = eList->exprs;
    SourcePos pos = eList->pos;
    bool isVaryingInit = false;
    bool isNotValidForMultiTargetGlobal = false;
    // A single-element list for an atomic, enum or pointer type is
    // handled by the element expression itself.
    if (exprs.size() == 1 && (CastType<AtomicType>(type) != NULL || CastType<EnumType>(type) != NULL ||
                              CastType<PointerType>(type) != NULL)) {
        if (isStorageType)
            return exprs[0]->GetStorageConstant(type);
        else
            return exprs[0]->GetConstant(type);
    }

    // Anything else must be a collection type, or a varying
    // (non-collection) type that is initialized value by value.
    const CollectionType *collectionType = CastType<CollectionType>(type);
    if (collectionType == NULL) {
        if (type->IsVaryingType() == true) {
            isVaryingInit = true;
        } else
            return std::pair<llvm::Constant *, bool>(NULL, false);
    }

    // Human-readable kind of the initialized type, used in error messages.
    std::string name;
    if (CastType<StructType>(type) != NULL)
        name = "struct";
    else if (CastType<ArrayType>(type) != NULL)
        name = "array";
    else if (CastType<VectorType>(type) != NULL)
        name = "vector";
    else if (isVaryingInit == true)
        name = "varying";
    else
        FATAL("Unexpected CollectionType in lGetExprListConstant");

    // A varying initializer needs exactly one value per vector lane; a
    // collection may be given fewer values than it has elements.
    int elementCount = (isVaryingInit == true) ? g->target->getVectorWidth() : collectionType->GetElementCount();
    if ((int)exprs.size() > elementCount) {
        const Type *errType = (isVaryingInit == true) ? type : collectionType;
        Error(pos,
              "Initializer list for %s \"%s\" must have no more than %d "
              "elements (has %d).",
              name.c_str(), errType->GetString().c_str(), elementCount, (int)exprs.size());
        return std::pair<llvm::Constant *, bool>(NULL, false);
    } else if ((isVaryingInit == true) && ((int)exprs.size() < elementCount)) {
        Error(pos,
              "Initializer list for %s \"%s\" must have %d "
              "elements (has %d).",
              name.c_str(), type->GetString().c_str(), elementCount, (int)exprs.size());
        return std::pair<llvm::Constant *, bool>(NULL, false);
    }

    // Convert each provided element to a constant of the element type.
    std::vector<llvm::Constant *> cv;
    for (unsigned int i = 0; i < exprs.size(); ++i) {
        if (exprs[i] == NULL)
            return std::pair<llvm::Constant *, bool>(NULL, false);
        const Type *elementType =
            (isVaryingInit == true) ? type->GetAsUniformType() : collectionType->GetElementType(i);

        Expr *expr = exprs[i];

        // Nested initializer lists are passed through as they are.
        if (llvm::dyn_cast<ExprList>(expr) == NULL) {
            // If there's a simple type conversion from the type of this
            // expression to the type we need, then let the regular type
            // conversion machinery handle it.
            expr = TypeConvertExpr(exprs[i], elementType, "initializer list");
            if (expr == NULL) {
                AssertPos(pos, m->errorCount > 0);
                return std::pair<llvm::Constant *, bool>(NULL, false);
            }
            // Re-establish const-ness if possible
            expr = ::Optimize(expr);
        }
        std::pair<llvm::Constant *, bool> cPair;
        if (isStorageType)
            cPair = expr->GetStorageConstant(elementType);
        else
            cPair = expr->GetConstant(elementType);
        llvm::Constant *c = cPair.first;
        if (c == NULL)
            // If this list element couldn't convert to the right constant
            // type for the corresponding collection member, then give up.
            return std::pair<llvm::Constant *, bool>(NULL, false);
        isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || cPair.second;
        cv.push_back(c);
    }

    // If there are too few, then treat missing ones as if they were zero
    if (isVaryingInit == false) {
        for (int i = (int)exprs.size(); i < collectionType->GetElementCount(); ++i) {
            const Type *elementType = collectionType->GetElementType(i);
            if (elementType == NULL) {
                AssertPos(pos, m->errorCount > 0);
                return std::pair<llvm::Constant *, bool>(NULL, false);
            }
            llvm::Type *llvmType = elementType->LLVMType(g->ctx);
            if (llvmType == NULL) {
                AssertPos(pos, m->errorCount > 0);
                return std::pair<llvm::Constant *, bool>(NULL, false);
            }

            llvm::Constant *c = llvm::Constant::getNullValue(llvmType);
            cv.push_back(c);
        }
    }

    // Assemble the element constants into a constant of the LLVM type
    // that corresponds to the initialized type.
    if (CastType<StructType>(type) != NULL) {
        llvm::StructType *llvmStructType = llvm::dyn_cast<llvm::StructType>(collectionType->LLVMType(g->ctx));
        AssertPos(pos, llvmStructType != NULL);
        return std::pair<llvm::Constant *, bool>(llvm::ConstantStruct::get(llvmStructType, cv),
                                                 isNotValidForMultiTargetGlobal);
    } else {
        llvm::Type *lt = type->LLVMType(g->ctx);
        llvm::ArrayType *lat = llvm::dyn_cast<llvm::ArrayType>(lt);
        if (lat != NULL)
            return std::pair<llvm::Constant *, bool>(llvm::ConstantArray::get(lat, cv), isNotValidForMultiTargetGlobal);
        else if (type->IsVaryingType()) {
            // varying type whose LLVM type is a vector: pad with undef
            // values up to a multiple of the target's vector width
            llvm::VectorType *lvt = llvm::dyn_cast<llvm::VectorType>(lt);
            AssertPos(pos, lvt != NULL);
            int vectorWidth = g->target->getVectorWidth();

            while ((cv.size() % vectorWidth) != 0) {
                cv.push_back(llvm::UndefValue::get(lvt->getElementType()));
            }

            return std::pair<llvm::Constant *, bool>(llvm::ConstantVector::get(cv), isNotValidForMultiTargetGlobal);
        } else {
            // uniform short vector type
            AssertPos(pos, type->IsUniformType() && CastType<VectorType>(type) != NULL);

            llvm::VectorType *lvt = llvm::dyn_cast<llvm::VectorType>(lt);
            AssertPos(pos, lvt != NULL);

            // Uniform short vectors are stored as vectors of length
            // rounded up to a power of 2 bits in size but not less than 128 bit.
            // So we add additional undef values here until we get the right size.
            const VectorType *vt = CastType<VectorType>(type);
            int vectorWidth = vt->getVectorMemoryCount();

            while ((cv.size() % vectorWidth) != 0) {
                cv.push_back(llvm::UndefValue::get(lvt->getElementType()));
            }

            return std::pair<llvm::Constant *, bool>(llvm::ConstantVector::get(cv), isNotValidForMultiTargetGlobal);
        }
    }
    // Every branch above returns, so this return is never reached.
    return std::pair<llvm::Constant *, bool>(NULL, false);
}
4009
GetStorageConstant(const Type * type) const4010 std::pair<llvm::Constant *, bool> ExprList::GetStorageConstant(const Type *type) const {
4011 return lGetExprListConstant(type, this, true);
4012 }
GetConstant(const Type * type) const4013 std::pair<llvm::Constant *, bool> ExprList::GetConstant(const Type *type) const {
4014 return lGetExprListConstant(type, this, false);
4015 }
4016
EstimateCost() const4017 int ExprList::EstimateCost() const { return 0; }
4018
Print() const4019 void ExprList::Print() const {
4020 printf("expr list (");
4021 for (unsigned int i = 0; i < exprs.size(); ++i) {
4022 if (exprs[i] != NULL)
4023 exprs[i]->Print();
4024 printf("%s", (i == exprs.size() - 1) ? ")" : ", ");
4025 }
4026 pos.Print();
4027 }
4028
4029 ///////////////////////////////////////////////////////////////////////////
4030 // IndexExpr
4031
IndexExpr(Expr * a,Expr * i,SourcePos p)4032 IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p) : Expr(p, IndexExprID) {
4033 baseExpr = a;
4034 index = i;
4035 type = lvalueType = NULL;
4036 }
4037
4038 /** When computing pointer values, we need to apply a per-lane offset when
4039 we have a varying pointer that is itself indexing into varying data.
    Consider the following ispc code:
4041
4042 uniform float u[] = ...;
4043 float v[] = ...;
4044 int index = ...;
4045 float a = u[index];
4046 float b = v[index];
4047
4048 To compute the varying pointer that holds the addresses to load from
4049 for u[index], we basically just need to multiply index element-wise by
4050 sizeof(float) before doing the memory load. For v[index], we need to
4051 do the same scaling but also need to add per-lane offsets <0,
4052 sizeof(float), 2*sizeof(float), ...> so that the i'th lane loads the
4053 i'th of the varying values at its index value.
4054
4055 This function handles figuring out when this additional offset is
4056 needed and then incorporates it in the varying pointer value.
4057 */
lAddVaryingOffsetsIfNeeded(FunctionEmitContext * ctx,llvm::Value * ptr,const Type * ptrRefType)4058 static llvm::Value *lAddVaryingOffsetsIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr, const Type *ptrRefType) {
4059 if (CastType<ReferenceType>(ptrRefType) != NULL)
4060 // References are uniform pointers, so no offsetting is needed
4061 return ptr;
4062
4063 const PointerType *ptrType = CastType<PointerType>(ptrRefType);
4064 Assert(ptrType != NULL);
4065 if (ptrType->IsUniformType() || ptrType->IsSlice())
4066 return ptr;
4067
4068 const Type *baseType = ptrType->GetBaseType();
4069 if (baseType->IsVaryingType() == false)
4070 return ptr;
4071
4072 // must be indexing into varying atomic, enum, or pointer types
4073 if (Type::IsBasicType(baseType) == false)
4074 return ptr;
4075
4076 // Onward: compute the per lane offsets.
4077 llvm::Value *varyingOffsets = ctx->ProgramIndexVector();
4078
4079 // And finally add the per-lane offsets. Note that we lie to the GEP
4080 // call and tell it that the pointers are to uniform elements and not
4081 // varying elements, so that the offsets in terms of (0,1,2,...) will
4082 // end up turning into the correct step in bytes...
4083 const Type *uniformElementType = baseType->GetAsUniformType();
4084 const Type *ptrUnifType = PointerType::GetVarying(uniformElementType);
4085 return ctx->GetElementPtrInst(ptr, varyingOffsets, ptrUnifType);
4086 }
4087
4088 /** Check to see if the given type is an array of or pointer to a varying
4089 struct type that in turn has a member with bound 'uniform' variability.
4090 Issue an error and return true if such a member is found.
4091 */
lVaryingStructHasUniformMember(const Type * type,SourcePos pos)4092 static bool lVaryingStructHasUniformMember(const Type *type, SourcePos pos) {
4093 if (CastType<VectorType>(type) != NULL || CastType<ReferenceType>(type) != NULL)
4094 return false;
4095
4096 const StructType *st = CastType<StructType>(type);
4097 if (st == NULL) {
4098 const ArrayType *at = CastType<ArrayType>(type);
4099 if (at != NULL)
4100 st = CastType<StructType>(at->GetElementType());
4101 else {
4102 const PointerType *pt = CastType<PointerType>(type);
4103 if (pt == NULL)
4104 return false;
4105
4106 st = CastType<StructType>(pt->GetBaseType());
4107 }
4108
4109 if (st == NULL)
4110 return false;
4111 }
4112
4113 if (st->IsVaryingType() == false)
4114 return false;
4115
4116 for (int i = 0; i < st->GetElementCount(); ++i) {
4117 const Type *eltType = st->GetElementType(i);
4118 if (eltType == NULL) {
4119 AssertPos(pos, m->errorCount > 0);
4120 continue;
4121 }
4122
4123 if (CastType<StructType>(eltType) != NULL) {
4124 // We know that the enclosing struct is varying at this point,
4125 // so push that down to the enclosed struct before makign the
4126 // recursive call.
4127 eltType = eltType->GetAsVaryingType();
4128 if (lVaryingStructHasUniformMember(eltType, pos))
4129 return true;
4130 } else if (eltType->IsUniformType()) {
4131 Error(pos,
4132 "Gather operation is impossible due to the presence of "
4133 "struct member \"%s\" with uniform type \"%s\" in the "
4134 "varying struct type \"%s\".",
4135 st->GetElementName(i).c_str(), eltType->GetString().c_str(), st->GetString().c_str());
4136 return true;
4137 }
4138 }
4139
4140 return false;
4141 }
4142
/** Emits the code that loads the value of the indexed element.

    The element address normally comes from GetLValue().  When that returns
    NULL, the base expression's value is stored into a temporary alloca and
    the element is addressed inside that temporary instead.

    @param ctx  Emission context that receives the generated instructions.
    @return The loaded value, or NULL if a sub-expression or a required type
            is missing, or if the gather check below reports an error.
*/
llvm::Value *IndexExpr::GetValue(FunctionEmitContext *ctx) const {
    const Type *indexType, *returnType;
    if (baseExpr == NULL || index == NULL || ((indexType = index->GetType()) == NULL) ||
        ((returnType = GetType()) == NULL)) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    // If this is going to be a gather, make sure that the varying return
    // type can represent the result (i.e. that we don't have a bound
    // 'uniform' member in a varying struct...)
    if (indexType->IsVaryingType() && lVaryingStructHasUniformMember(returnType, pos))
        return NULL;

    ctx->SetDebugPos(pos);

    llvm::Value *ptr = GetLValue(ctx);
    llvm::Value *mask = NULL;
    const Type *lvType = GetLValueType();
    if (ptr == NULL) {
        // We may be indexing into a temporary that hasn't hit memory, so
        // get the full value and stuff it into temporary alloca'd space so
        // that we can index from there...
        const Type *baseExprType = baseExpr->GetType();
        llvm::Value *val = baseExpr->GetValue(ctx);
        if (baseExprType == NULL || val == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }
        ctx->SetDebugPos(pos);
        llvm::Value *tmpPtr = ctx->AllocaInst(baseExprType, "array_tmp");
        ctx->StoreInst(val, tmpPtr, baseExprType, baseExprType->IsUniformType());

        // Get a pointer type to the underlying elements
        const SequentialType *st = CastType<SequentialType>(baseExprType);
        if (st == NULL) {
            Assert(m->errorCount > 0);
            return NULL;
        }
        // The lvalue type from GetLValueType() is replaced here by a uniform
        // pointer to the temporary's element type.
        lvType = PointerType::GetUniform(st->GetElementType());

        // And do the indexing calculation into the temporary array in memory
        ptr = ctx->GetElementPtrInst(tmpPtr, LLVMInt32(0), index->GetValue(ctx), PointerType::GetUniform(baseExprType));
        ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, lvType);

        // The temporary was just stored by this function, so the load uses
        // an all-on mask.
        mask = LLVMMaskAllOn;
    } else {
        Symbol *baseSym = GetBaseSymbol();
        if (llvm::dyn_cast<FunctionCallExpr>(baseExpr) == NULL && llvm::dyn_cast<BinaryExpr>(baseExpr) == NULL) {
            // Don't check if we're doing a function call or pointer arith
            AssertPos(pos, baseSym != NULL);
        }
        // The load mask is chosen by lMaskForSymbol() from the base symbol.
        mask = lMaskForSymbol(baseSym, ctx);
    }

    ctx->SetDebugPos(pos);
    return ctx->LoadInst(ptr, mask, lvType);
}
4201
GetType() const4202 const Type *IndexExpr::GetType() const {
4203 if (type != NULL)
4204 return type;
4205
4206 const Type *baseExprType, *indexType;
4207 if (!baseExpr || !index || ((baseExprType = baseExpr->GetType()) == NULL) ||
4208 ((indexType = index->GetType()) == NULL))
4209 return NULL;
4210
4211 const Type *elementType = NULL;
4212 const PointerType *pointerType = CastType<PointerType>(baseExprType);
4213 if (pointerType != NULL)
4214 // ptr[index] -> type that the pointer points to
4215 elementType = pointerType->GetBaseType();
4216 else if (const SequentialType *sequentialType = CastType<SequentialType>(baseExprType->GetReferenceTarget()))
4217 // sequential type[index] -> element type of the sequential type
4218 elementType = sequentialType->GetElementType();
4219 else
4220 // Not an expression that can be indexed into. Will result in error.
4221 return NULL;
4222
4223 // If we're indexing into a sequence of SOA types, the result type is
4224 // actually the underlying type, as a uniform or varying. Get the
4225 // uniform variant of it for starters, then below we'll make it varying
4226 // if the index is varying.
4227 // (If we ever provide a way to index into SOA types and get an entire
4228 // SOA'd struct out of the array, then we won't want to do this in that
4229 // case..)
4230 if (elementType->IsSOAType())
4231 elementType = elementType->GetAsUniformType();
4232
4233 // If either the index is varying or we're indexing into a varying
4234 // pointer, then the result type is the varying variant of the indexed
4235 // type.
4236 if (indexType->IsUniformType() && (pointerType == NULL || pointerType->IsUniformType()))
4237 type = elementType;
4238 else
4239 type = elementType->GetAsVaryingType();
4240
4241 return type;
4242 }
4243
GetBaseSymbol() const4244 Symbol *IndexExpr::GetBaseSymbol() const { return baseExpr ? baseExpr->GetBaseSymbol() : NULL; }
4245
4246 /** Utility routine that takes a regualr pointer (either uniform or
4247 varying) and returns a slice pointer with zero offsets.
4248 */
lConvertToSlicePointer(FunctionEmitContext * ctx,llvm::Value * ptr,const PointerType * slicePtrType)4249 static llvm::Value *lConvertToSlicePointer(FunctionEmitContext *ctx, llvm::Value *ptr,
4250 const PointerType *slicePtrType) {
4251 llvm::Type *llvmSlicePtrType = slicePtrType->LLVMType(g->ctx);
4252 llvm::StructType *sliceStructType = llvm::dyn_cast<llvm::StructType>(llvmSlicePtrType);
4253 Assert(sliceStructType != NULL && sliceStructType->getElementType(0) == ptr->getType());
4254
4255 // Get a null-initialized struct to take care of having zeros for the
4256 // offsets
4257 llvm::Value *result = llvm::Constant::getNullValue(sliceStructType);
4258 // And replace the pointer in the struct with the given pointer
4259 return ctx->InsertInst(result, ptr, 0, llvm::Twine(ptr->getName()) + "_slice");
4260 }
4261
4262 /** If the given array index is a compile time constant, check to see if it
4263 value/values don't go past the end of the array; issue a warning if
4264 so.
4265 */
lCheckIndicesVersusBounds(const Type * baseExprType,Expr * index)4266 static void lCheckIndicesVersusBounds(const Type *baseExprType, Expr *index) {
4267 const SequentialType *seqType = CastType<SequentialType>(baseExprType);
4268 if (seqType == NULL)
4269 return;
4270
4271 int nElements = seqType->GetElementCount();
4272 if (nElements == 0)
4273 // Unsized array...
4274 return;
4275
4276 // If it's an array of soa<> items, then the number of elements to
4277 // worry about w.r.t. index values is the product of the array size and
4278 // the soa width.
4279 int soaWidth = seqType->GetElementType()->GetSOAWidth();
4280 if (soaWidth > 0)
4281 nElements *= soaWidth;
4282
4283 ConstExpr *ce = llvm::dyn_cast<ConstExpr>(index);
4284 if (ce == NULL)
4285 return;
4286
4287 int32_t indices[ISPC_MAX_NVEC];
4288 int count = ce->GetValues(indices);
4289 for (int i = 0; i < count; ++i) {
4290 if (indices[i] < 0 || indices[i] >= nElements)
4291 Warning(index->pos,
4292 "Array index \"%d\" may be out of bounds for %d "
4293 "element array.",
4294 indices[i], nElements);
4295 }
4296 }
4297
4298 /** Converts the given pointer value to a slice pointer if the pointer
4299 points to SOA'ed data.
4300 */
lConvertPtrToSliceIfNeeded(FunctionEmitContext * ctx,llvm::Value * ptr,const Type ** type)4301 static llvm::Value *lConvertPtrToSliceIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr, const Type **type) {
4302 Assert(*type != NULL);
4303 const PointerType *ptrType = CastType<PointerType>(*type);
4304 Assert(ptrType != NULL);
4305 bool convertToSlice = (ptrType->GetBaseType()->IsSOAType() && ptrType->IsSlice() == false);
4306 if (convertToSlice == false)
4307 return ptr;
4308
4309 *type = ptrType->GetAsSlice();
4310 return lConvertToSlicePointer(ctx, ptr, ptrType->GetAsSlice());
4311 }
4312
/** Emits the code that computes the address of the indexed element.

    Two cases are handled: indexing off a pointer value, and indexing into
    an array or vector (directly or through a reference).

    @param ctx  Emission context that receives the generated instructions.
    @return The element pointer, or NULL if a sub-expression, its type, or
            its value/lvalue is unavailable.
*/
llvm::Value *IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
    const Type *baseExprType;
    if (baseExpr == NULL || index == NULL || ((baseExprType = baseExpr->GetType()) == NULL)) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    // The index value is emitted first, before anything from the base
    // expression.
    ctx->SetDebugPos(pos);
    llvm::Value *indexValue = index->GetValue(ctx);
    if (indexValue == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    ctx->SetDebugPos(pos);
    if (CastType<PointerType>(baseExprType) != NULL) {
        // We're indexing off of a pointer
        llvm::Value *basePtrValue = baseExpr->GetValue(ctx);
        if (basePtrValue == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }
        ctx->SetDebugPos(pos);

        // Convert to a slice pointer if we're indexing into SOA data
        // (this may also update baseExprType to the slice pointer type).
        basePtrValue = lConvertPtrToSliceIfNeeded(ctx, basePtrValue, &baseExprType);

        llvm::Value *ptr = ctx->GetElementPtrInst(basePtrValue, indexValue, baseExprType,
                                                  llvm::Twine(basePtrValue->getName()) + "_offset");
        return lAddVaryingOffsetsIfNeeded(ctx, ptr, GetLValueType());
    }

    // Not a pointer: we must be indexing an array or vector (possibly
    // through a reference to one).
    llvm::Value *basePtr = NULL;
    const PointerType *basePtrType = NULL;
    if (CastType<ArrayType>(baseExprType) || CastType<VectorType>(baseExprType)) {
        // Direct array/vector: its address is the base expression's lvalue.
        basePtr = baseExpr->GetLValue(ctx);
        basePtrType = CastType<PointerType>(baseExpr->GetLValueType());
        if (baseExpr->GetLValueType())
            AssertPos(pos, basePtrType != NULL);
    } else {
        // Reference to an array/vector: the base expression's value is
        // used as the pointer to the referenced object.
        baseExprType = baseExprType->GetReferenceTarget();
        AssertPos(pos, CastType<ArrayType>(baseExprType) || CastType<VectorType>(baseExprType));
        basePtr = baseExpr->GetValue(ctx);
        basePtrType = PointerType::GetUniform(baseExprType);
    }
    if (!basePtr)
        return NULL;

    // If possible, check the index value(s) against the size of the array
    lCheckIndicesVersusBounds(baseExprType, index);

    // Convert to a slice pointer if indexing into SOA data
    basePtr = lConvertPtrToSliceIfNeeded(ctx, basePtr, (const Type **)&basePtrType);

    ctx->SetDebugPos(pos);

    // And do the actual indexing calculation..
    llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, LLVMInt32(0), indexValue, basePtrType,
                                              llvm::Twine(basePtr->getName()) + "_offset");
    return lAddVaryingOffsetsIfNeeded(ctx, ptr, GetLValueType());
}
4376
GetLValueType() const4377 const Type *IndexExpr::GetLValueType() const {
4378 if (lvalueType != NULL)
4379 return lvalueType;
4380
4381 const Type *baseExprType, *baseExprLValueType, *indexType;
4382 if (baseExpr == NULL || index == NULL || ((baseExprType = baseExpr->GetType()) == NULL) ||
4383 ((baseExprLValueType = baseExpr->GetLValueType()) == NULL) || ((indexType = index->GetType()) == NULL))
4384 return NULL;
4385
4386 // regularize to a PointerType
4387 if (CastType<ReferenceType>(baseExprLValueType) != NULL) {
4388 const Type *refTarget = baseExprLValueType->GetReferenceTarget();
4389 baseExprLValueType = PointerType::GetUniform(refTarget);
4390 }
4391 AssertPos(pos, CastType<PointerType>(baseExprLValueType) != NULL);
4392
4393 // Find the type of thing that we're indexing into
4394 const Type *elementType;
4395 const SequentialType *st = CastType<SequentialType>(baseExprLValueType->GetBaseType());
4396 if (st != NULL)
4397 elementType = st->GetElementType();
4398 else {
4399 const PointerType *pt = CastType<PointerType>(baseExprLValueType->GetBaseType());
4400 // This assertion seems overly strict.
4401 // Why does it need to be a pointer to a pointer?
4402 // AssertPos(pos, pt != NULL);
4403
4404 if (pt != NULL) {
4405 elementType = pt->GetBaseType();
4406 } else {
4407 elementType = baseExprLValueType->GetBaseType();
4408 }
4409 }
4410
4411 // Are we indexing into a varying type, or are we indexing with a
4412 // varying pointer?
4413 bool baseVarying;
4414 if (CastType<PointerType>(baseExprType) != NULL)
4415 baseVarying = baseExprType->IsVaryingType();
4416 else
4417 baseVarying = baseExprLValueType->IsVaryingType();
4418
4419 // The return type is uniform iff. the base is a uniform pointer / a
4420 // collection of uniform typed elements and the index is uniform.
4421 if (baseVarying == false && indexType->IsUniformType())
4422 lvalueType = PointerType::GetUniform(elementType);
4423 else
4424 lvalueType = PointerType::GetVarying(elementType);
4425
4426 // Finally, if we're indexing into an SOA type, then the resulting
4427 // pointer must (currently) be a slice pointer; we don't allow indexing
4428 // the soa-width-wide structs directly.
4429 if (elementType->IsSOAType())
4430 lvalueType = lvalueType->GetAsSlice();
4431
4432 return lvalueType;
4433 }
4434
Optimize()4435 Expr *IndexExpr::Optimize() {
4436 if (baseExpr == NULL || index == NULL)
4437 return NULL;
4438 return this;
4439 }
4440
TypeCheck()4441 Expr *IndexExpr::TypeCheck() {
4442 const Type *indexType;
4443 if (baseExpr == NULL || index == NULL || ((indexType = index->GetType()) == NULL)) {
4444 AssertPos(pos, m->errorCount > 0);
4445 return NULL;
4446 }
4447
4448 const Type *baseExprType = baseExpr->GetType();
4449 if (baseExprType == NULL) {
4450 AssertPos(pos, m->errorCount > 0);
4451 return NULL;
4452 }
4453
4454 if (!CastType<SequentialType>(baseExprType->GetReferenceTarget())) {
4455 if (const PointerType *pt = CastType<PointerType>(baseExprType)) {
4456 if (pt->GetBaseType()->IsVoidType()) {
4457 Error(pos, "Illegal to dereference void pointer type \"%s\".", baseExprType->GetString().c_str());
4458 return NULL;
4459 }
4460 } else {
4461 Error(pos,
4462 "Trying to index into non-array, vector, or pointer "
4463 "type \"%s\".",
4464 baseExprType->GetString().c_str());
4465 return NULL;
4466 }
4467 }
4468
4469 bool isUniform = (index->GetType()->IsUniformType() && !g->opt.disableUniformMemoryOptimizations);
4470
4471 if (!isUniform) {
4472 // Unless we have an explicit 64-bit index and are compiling to a
4473 // 64-bit target with 64-bit addressing, convert the index to an int32
4474 // type.
4475 // The range of varying index is limited to [0,2^31) as a result.
4476 if (!(Type::EqualIgnoringConst(indexType->GetAsUniformType(), AtomicType::UniformUInt64) ||
4477 Type::EqualIgnoringConst(indexType->GetAsUniformType(), AtomicType::UniformInt64)) ||
4478 g->target->is32Bit() || g->opt.force32BitAddressing) {
4479 const Type *indexType = AtomicType::VaryingInt32;
4480 index = TypeConvertExpr(index, indexType, "array index");
4481 if (index == NULL)
4482 return NULL;
4483 }
4484 } else { // isUniform
4485 // For 32-bit target:
4486 // force the index to 32 bit.
4487 // For 64-bit target:
4488 // We don't want to limit the index range.
4489 // We sxt/zxt the index to 64 bit right here because
4490 // LLVM doesn't distinguish unsigned from signed (both are i32)
4491 //
4492 // However, the index can be still truncated to signed int32 if
4493 // the index type is 64 bit and --addressing=32.
4494 bool force_32bit =
4495 g->target->is32Bit() || (g->opt.force32BitAddressing &&
4496 Type::EqualIgnoringConst(indexType->GetAsUniformType(), AtomicType::UniformInt64));
4497 const Type *indexType = force_32bit ? AtomicType::UniformInt32 : AtomicType::UniformInt64;
4498 index = TypeConvertExpr(index, indexType, "array index");
4499 if (index == NULL)
4500 return NULL;
4501 }
4502
4503 return this;
4504 }
4505
EstimateCost() const4506 int IndexExpr::EstimateCost() const {
4507 if (index == NULL || baseExpr == NULL)
4508 return 0;
4509
4510 const Type *indexType = index->GetType();
4511 const Type *baseExprType = baseExpr->GetType();
4512
4513 if ((indexType != NULL && indexType->IsVaryingType()) ||
4514 (CastType<PointerType>(baseExprType) != NULL && baseExprType->IsVaryingType()))
4515 // be pessimistic; some of these will later turn out to be vector
4516 // loads/stores, but it's too early for us to know that here.
4517 return COST_GATHER;
4518 else
4519 return COST_LOAD;
4520 }
4521
Print() const4522 void IndexExpr::Print() const {
4523 if (!baseExpr || !index || !GetType())
4524 return;
4525
4526 printf("[%s] index ", GetType()->GetString().c_str());
4527 baseExpr->Print();
4528 printf("[");
4529 index->Print();
4530 printf("]");
4531 pos.Print();
4532 }
4533
4534 ///////////////////////////////////////////////////////////////////////////
4535 // MemberExpr
4536
/** Maps a one-character identifier to a vector element number.  Several
    naming conventions are accepted: xyzw, rgba, and uv.

    @param id  Candidate element character.
    @return The element number in [0,3], or -1 if the character isn't one
            of the recognized names.
 */
static int lIdentifierToVectorElement(char id) {
    // Entry i lists every character that names element i.
    static const char *const elementNames[] = {"xru", "ygv", "zb", "wa"};
    const int elementCount = static_cast<int>(sizeof(elementNames) / sizeof(elementNames[0]));

    for (int element = 0; element < elementCount; ++element) {
        // Walk the names by hand so that a '\0' id never matches the
        // string terminator.
        for (const char *name = elementNames[element]; *name != '\0'; ++name) {
            if (*name == id)
                return element;
        }
    }
    return -1;
}
4560
4561 //////////////////////////////////////////////////
4562 // StructMemberExpr
4563
/** MemberExpr specialization for member access on struct-typed
    expressions. */
class StructMemberExpr : public MemberExpr {
  public:
    StructMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue);

    // Type-identification hooks keyed on the node's value ID
    // (presumably consumed by llvm::dyn_cast and friends).
    static inline bool classof(StructMemberExpr const *) { return true; }
    static inline bool classof(ASTNode const *N) { return N->getValueID() == StructMemberExprID; }

    // Type of the accessed member.
    const Type *GetType() const;
    // Pointer type of the accessed member's address.
    const Type *GetLValueType() const;
    // Position of the member within the struct, or -1 if it isn't found.
    int getElementNumber() const;
    // Declared type of the member as recorded in the struct type.
    const Type *getElementType() const;

  private:
    // Struct type that the member is being read from.
    const StructType *getStructType() const;
};
4579
/** Forwards all arguments to MemberExpr, tagging the node with
    StructMemberExprID. */
StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue)
    : MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) {}
4582
GetType() const4583 const Type *StructMemberExpr::GetType() const {
4584 if (type != NULL)
4585 return type;
4586
4587 // It's a struct, and the result type is the element type, possibly
4588 // promoted to varying if the struct type / lvalue is varying.
4589 const Type *exprType, *lvalueType;
4590 const StructType *structType;
4591 if (expr == NULL || ((exprType = expr->GetType()) == NULL) || ((structType = getStructType()) == NULL) ||
4592 ((lvalueType = GetLValueType()) == NULL)) {
4593 AssertPos(pos, m->errorCount > 0);
4594 return NULL;
4595 }
4596
4597 const Type *elementType = structType->GetElementType(identifier);
4598 if (elementType == NULL) {
4599 Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s", identifier.c_str(),
4600 structType->GetString().c_str(), getCandidateNearMatches().c_str());
4601 return NULL;
4602 }
4603 AssertPos(pos, Type::Equal(lvalueType->GetBaseType(), elementType));
4604
4605 bool isSlice = (CastType<PointerType>(lvalueType) && CastType<PointerType>(lvalueType)->IsSlice());
4606 if (isSlice) {
4607 // FIXME: not true if we allow bound unif/varying for soa<>
4608 // structs?...
4609 AssertPos(pos, elementType->IsSOAType());
4610
4611 // If we're accessing a member of an soa structure via a uniform
4612 // slice pointer, then the result type is the uniform variant of
4613 // the element type.
4614 if (lvalueType->IsUniformType())
4615 elementType = elementType->GetAsUniformType();
4616 }
4617
4618 if (lvalueType->IsVaryingType())
4619 // If the expression we're getting the member of has an lvalue that
4620 // is a varying pointer type (be it slice or non-slice), then the
4621 // result type must be the varying version of the element type.
4622 elementType = elementType->GetAsVaryingType();
4623
4624 type = elementType;
4625 return type;
4626 }
4627
GetLValueType() const4628 const Type *StructMemberExpr::GetLValueType() const {
4629 if (lvalueType != NULL)
4630 return lvalueType;
4631
4632 if (expr == NULL) {
4633 AssertPos(pos, m->errorCount > 0);
4634 return NULL;
4635 }
4636
4637 const Type *exprLValueType = dereferenceExpr ? expr->GetType() : expr->GetLValueType();
4638 if (exprLValueType == NULL) {
4639 AssertPos(pos, m->errorCount > 0);
4640 return NULL;
4641 }
4642
4643 // The pointer type is varying if the lvalue type of the expression is
4644 // varying (and otherwise uniform)
4645 const PointerType *ptrType = (exprLValueType->IsUniformType() || CastType<ReferenceType>(exprLValueType) != NULL)
4646 ? PointerType::GetUniform(getElementType())
4647 : PointerType::GetVarying(getElementType());
4648
4649 // If struct pointer is a slice pointer, the resulting member pointer
4650 // needs to be a frozen slice pointer--i.e. any further indexing with
4651 // the result shouldn't modify the minor slice offset, but it should be
4652 // left unchanged until we get to a leaf SOA value.
4653 if (CastType<PointerType>(exprLValueType) && CastType<PointerType>(exprLValueType)->IsSlice())
4654 ptrType = ptrType->GetAsFrozenSlice();
4655
4656 lvalueType = ptrType;
4657 return lvalueType;
4658 }
4659
getElementNumber() const4660 int StructMemberExpr::getElementNumber() const {
4661 const StructType *structType = getStructType();
4662 if (structType == NULL)
4663 return -1;
4664
4665 int elementNumber = structType->GetElementNumber(identifier);
4666 if (elementNumber == -1)
4667 Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s", identifier.c_str(),
4668 structType->GetString().c_str(), getCandidateNearMatches().c_str());
4669
4670 return elementNumber;
4671 }
4672
getElementType() const4673 const Type *StructMemberExpr::getElementType() const {
4674 const StructType *structType = getStructType();
4675 if (structType == NULL)
4676 return NULL;
4677
4678 return structType->GetElementType(identifier);
4679 }
4680
4681 /** Returns the type of the underlying struct that we're returning a member
4682 of. */
getStructType() const4683 const StructType *StructMemberExpr::getStructType() const {
4684 const Type *type = dereferenceExpr ? expr->GetType() : expr->GetLValueType();
4685 if (type == NULL)
4686 return NULL;
4687
4688 const Type *structType;
4689 const ReferenceType *rt = CastType<ReferenceType>(type);
4690 if (rt != NULL)
4691 structType = rt->GetReferenceTarget();
4692 else {
4693 const PointerType *pt = CastType<PointerType>(type);
4694 AssertPos(pos, pt != NULL);
4695 structType = pt->GetBaseType();
4696 }
4697
4698 const StructType *ret = CastType<StructType>(structType);
4699 AssertPos(pos, ret != NULL);
4700 return ret;
4701 }
4702
4703 //////////////////////////////////////////////////
4704 // VectorMemberExpr
4705
/** MemberExpr specialization for element access and multi-element
    swizzles on short-vector-typed expressions. */
class VectorMemberExpr : public MemberExpr {
  public:
    VectorMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue);

    // Type-identification hooks keyed on the node's value ID
    // (presumably consumed by llvm::dyn_cast and friends).
    static inline bool classof(VectorMemberExpr const *) { return true; }
    static inline bool classof(ASTNode const *N) { return N->getValueID() == VectorMemberExprID; }

    llvm::Value *GetValue(FunctionEmitContext *ctx) const;
    llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
    const Type *GetType() const;
    const Type *GetLValueType() const;

    // Element number named by the first identifier character, or -1.
    int getElementNumber() const;
    // Returns memberType (see below).
    const Type *getElementType() const;

  private:
    // Vector type of the base expression (found directly, through a
    // pointer, or through a reference).
    const VectorType *exprVectorType;
    // Vector type with one element per identifier character.
    const VectorType *memberType;
};
4725
VectorMemberExpr(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue)4726 VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue)
4727 : MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) {
4728 const Type *exprType = e->GetType();
4729 exprVectorType = CastType<VectorType>(exprType);
4730 if (exprVectorType == NULL) {
4731 const PointerType *pt = CastType<PointerType>(exprType);
4732 if (pt != NULL)
4733 exprVectorType = CastType<VectorType>(pt->GetBaseType());
4734 else {
4735 AssertPos(pos, CastType<ReferenceType>(exprType) != NULL);
4736 exprVectorType = CastType<VectorType>(exprType->GetReferenceTarget());
4737 }
4738 AssertPos(pos, exprVectorType != NULL);
4739 }
4740 memberType = new VectorType(exprVectorType->GetElementType(), identifier.length());
4741 }
4742
GetType() const4743 const Type *VectorMemberExpr::GetType() const {
4744 if (type != NULL)
4745 return type;
4746
4747 // For 1-element expressions, we have the base vector element
4748 // type. For n-element expressions, we have a shortvec type
4749 // with n > 1 elements. This can be changed when we get
4750 // type<1> -> type conversions.
4751 type = (identifier.length() == 1) ? (const Type *)exprVectorType->GetElementType() : (const Type *)memberType;
4752
4753 const Type *lvType = GetLValueType();
4754 if (lvType != NULL) {
4755 bool isSlice = (CastType<PointerType>(lvType) && CastType<PointerType>(lvType)->IsSlice());
4756 if (isSlice) {
4757 // CO AssertPos(pos, type->IsSOAType());
4758 if (lvType->IsUniformType())
4759 type = type->GetAsUniformType();
4760 }
4761
4762 if (lvType->IsVaryingType())
4763 type = type->GetAsVaryingType();
4764 }
4765
4766 return type;
4767 }
4768
GetLValue(FunctionEmitContext * ctx) const4769 llvm::Value *VectorMemberExpr::GetLValue(FunctionEmitContext *ctx) const {
4770 if (identifier.length() == 1) {
4771 return MemberExpr::GetLValue(ctx);
4772 } else {
4773 return NULL;
4774 }
4775 }
4776
GetLValueType() const4777 const Type *VectorMemberExpr::GetLValueType() const {
4778 if (lvalueType != NULL)
4779 return lvalueType;
4780
4781 if (identifier.length() == 1) {
4782 if (expr == NULL) {
4783 AssertPos(pos, m->errorCount > 0);
4784 return NULL;
4785 }
4786
4787 const Type *exprLValueType = dereferenceExpr ? expr->GetType() : expr->GetLValueType();
4788 if (exprLValueType == NULL)
4789 return NULL;
4790
4791 const VectorType *vt = NULL;
4792 if (CastType<ReferenceType>(exprLValueType) != NULL)
4793 vt = CastType<VectorType>(exprLValueType->GetReferenceTarget());
4794 else
4795 vt = CastType<VectorType>(exprLValueType->GetBaseType());
4796 AssertPos(pos, vt != NULL);
4797
4798 // we don't want to report that it's e.g. a pointer to a float<1>,
4799 // but a pointer to a float, etc.
4800 const Type *elementType = vt->GetElementType();
4801 if (CastType<ReferenceType>(exprLValueType) != NULL)
4802 lvalueType = new ReferenceType(elementType);
4803 else {
4804 const PointerType *ptrType = exprLValueType->IsUniformType() ? PointerType::GetUniform(elementType)
4805 : PointerType::GetVarying(elementType);
4806 // FIXME: replicated logic with structmemberexpr....
4807 if (CastType<PointerType>(exprLValueType) && CastType<PointerType>(exprLValueType)->IsSlice())
4808 ptrType = ptrType->GetAsFrozenSlice();
4809 lvalueType = ptrType;
4810 }
4811 }
4812
4813 return lvalueType;
4814 }
4815
GetValue(FunctionEmitContext * ctx) const4816 llvm::Value *VectorMemberExpr::GetValue(FunctionEmitContext *ctx) const {
4817 if (identifier.length() == 1) {
4818 return MemberExpr::GetValue(ctx);
4819 } else {
4820 std::vector<int> indices;
4821
4822 for (size_t i = 0; i < identifier.size(); ++i) {
4823 int idx = lIdentifierToVectorElement(identifier[i]);
4824 if (idx == -1)
4825 Error(pos, "Invalid swizzle character '%c' in swizzle \"%s\".", identifier[i], identifier.c_str());
4826
4827 indices.push_back(idx);
4828 }
4829
4830 llvm::Value *basePtr = NULL;
4831 const Type *basePtrType = NULL;
4832 if (dereferenceExpr) {
4833 basePtr = expr->GetValue(ctx);
4834 basePtrType = expr->GetType();
4835 } else {
4836 basePtr = expr->GetLValue(ctx);
4837 basePtrType = expr->GetLValueType();
4838 }
4839
4840 if (basePtr == NULL || basePtrType == NULL) {
4841 // Check that expression on the left side is a rvalue expression
4842 llvm::Value *exprValue = expr->GetValue(ctx);
4843 basePtr = ctx->AllocaInst(expr->GetType());
4844 basePtrType = PointerType::GetUniform(exprVectorType);
4845 if (basePtr == NULL || basePtrType == NULL) {
4846 AssertPos(pos, m->errorCount > 0);
4847 return NULL;
4848 }
4849 ctx->StoreInst(exprValue, basePtr, expr->GetType(), expr->GetType()->IsUniformType());
4850 }
4851
4852 // Allocate temporary memory to store the result
4853 llvm::Value *resultPtr = ctx->AllocaInst(memberType, "vector_tmp");
4854
4855 if (resultPtr == NULL) {
4856 AssertPos(pos, m->errorCount > 0);
4857 return NULL;
4858 }
4859
4860 // FIXME: we should be able to use the internal mask here according
4861 // to the same logic where it's used elsewhere
4862 llvm::Value *elementMask = ctx->GetFullMask();
4863
4864 const Type *elementPtrType = NULL;
4865 if (CastType<ReferenceType>(basePtrType) != NULL)
4866 elementPtrType = PointerType::GetUniform(basePtrType->GetReferenceTarget());
4867 else
4868 elementPtrType = basePtrType->IsUniformType() ? PointerType::GetUniform(exprVectorType->GetElementType())
4869 : PointerType::GetVarying(exprVectorType->GetElementType());
4870
4871 ctx->SetDebugPos(pos);
4872 for (size_t i = 0; i < identifier.size(); ++i) {
4873 char idStr[2] = {identifier[i], '\0'};
4874 llvm::Value *elementPtr =
4875 ctx->AddElementOffset(basePtr, indices[i], basePtrType, llvm::Twine(basePtr->getName()) + idStr);
4876 llvm::Value *elementValue = ctx->LoadInst(elementPtr, elementMask, elementPtrType);
4877
4878 llvm::Value *ptmp = ctx->AddElementOffset(resultPtr, i, NULL, llvm::Twine(resultPtr->getName()) + idStr);
4879 ctx->StoreInst(elementValue, ptmp, elementPtrType, expr->GetType()->IsUniformType());
4880 }
4881
4882 return ctx->LoadInst(resultPtr, memberType, llvm::Twine(basePtr->getName()) + "_swizzle");
4883 }
4884 }
4885
getElementNumber() const4886 int VectorMemberExpr::getElementNumber() const {
4887 int elementNumber = lIdentifierToVectorElement(identifier[0]);
4888 if (elementNumber == -1)
4889 Error(pos, "Vector element identifier \"%s\" unknown.", identifier.c_str());
4890 return elementNumber;
4891 }
4892
getElementType() const4893 const Type *VectorMemberExpr::getElementType() const { return memberType; }
4894
create(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue)4895 MemberExpr *MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue) {
4896 // FIXME: we need to call TypeCheck() here so that we can call
4897 // e->GetType() in the following. But really we just shouldn't try to
4898 // resolve this now but just have a generic MemberExpr type that
4899 // handles all cases so that this is unnecessary.
4900 e = ::TypeCheck(e);
4901
4902 const Type *exprType;
4903 if (e == NULL || (exprType = e->GetType()) == NULL)
4904 return NULL;
4905
4906 const ReferenceType *referenceType = CastType<ReferenceType>(exprType);
4907 if (referenceType != NULL) {
4908 e = new RefDerefExpr(e, e->pos);
4909 exprType = e->GetType();
4910 Assert(exprType != NULL);
4911 }
4912
4913 const PointerType *pointerType = CastType<PointerType>(exprType);
4914 if (pointerType != NULL)
4915 exprType = pointerType->GetBaseType();
4916
4917 if (derefLValue == true && pointerType == NULL) {
4918 const Type *targetType = exprType->GetReferenceTarget();
4919 if (CastType<StructType>(targetType) != NULL)
4920 Error(p,
4921 "Member operator \"->\" can't be applied to non-pointer "
4922 "type \"%s\". Did you mean to use \".\"?",
4923 exprType->GetString().c_str());
4924 else
4925 Error(p,
4926 "Member operator \"->\" can't be applied to non-struct "
4927 "pointer type \"%s\".",
4928 exprType->GetString().c_str());
4929 return NULL;
4930 }
4931 // For struct and short-vector, emit error if elements are accessed
4932 // incorrectly.
4933 if (derefLValue == false && pointerType != NULL &&
4934 ((CastType<StructType>(pointerType->GetBaseType()) != NULL) ||
4935 (CastType<VectorType>(pointerType->GetBaseType()) != NULL))) {
4936 Error(p,
4937 "Member operator \".\" can't be applied to pointer "
4938 "type \"%s\". Did you mean to use \"->\"?",
4939 exprType->GetString().c_str());
4940 return NULL;
4941 }
4942 if (CastType<StructType>(exprType) != NULL) {
4943 const StructType *st = CastType<StructType>(exprType);
4944 if (st->IsDefined()) {
4945 return new StructMemberExpr(e, id, p, idpos, derefLValue);
4946 } else {
4947 Error(p,
4948 "Member operator \"%s\" can't be applied to declared "
4949 "struct \"%s\" containing an undefined struct type.",
4950 derefLValue ? "->" : ".", exprType->GetString().c_str());
4951 return NULL;
4952 }
4953 } else if (CastType<VectorType>(exprType) != NULL)
4954 return new VectorMemberExpr(e, id, p, idpos, derefLValue);
4955 else if (CastType<UndefinedStructType>(exprType)) {
4956 Error(p,
4957 "Member operator \"%s\" can't be applied to declared "
4958 "but not defined struct type \"%s\".",
4959 derefLValue ? "->" : ".", exprType->GetString().c_str());
4960 return NULL;
4961 } else {
4962 Error(p,
4963 "Member operator \"%s\" can't be used with expression of "
4964 "\"%s\" type.",
4965 derefLValue ? "->" : ".", exprType->GetString().c_str());
4966 return NULL;
4967 }
4968 }
4969
MemberExpr(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue,unsigned scid)4970 MemberExpr::MemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue, unsigned scid)
4971 : Expr(p, scid), identifierPos(idpos) {
4972 expr = e;
4973 identifier = id;
4974 dereferenceExpr = derefLValue;
4975 type = lvalueType = NULL;
4976 }
4977
/** Emits the code that loads the value of the accessed member.

    The member address normally comes from GetLValue().  When that returns
    NULL and no error has been recorded, the base expression's value is
    stored into a temporary alloca and the member is addressed inside that
    temporary instead.

    @param ctx  Emission context that receives the generated instructions.
    @return The loaded value, or NULL if the base expression is missing,
            an error has been recorded, or the member can't be located.
*/
llvm::Value *MemberExpr::GetValue(FunctionEmitContext *ctx) const {
    if (!expr)
        return NULL;

    llvm::Value *lvalue = GetLValue(ctx);
    // Note: this local hides the cached member of the same name.
    const Type *lvalueType = GetLValueType();

    llvm::Value *mask = NULL;
    if (lvalue == NULL) {
        if (m->errorCount > 0)
            return NULL;

        // As in the array case, this may be a temporary that hasn't hit
        // memory; get the full value and stuff it into a temporary array
        // so that we can index from there...
        llvm::Value *val = expr->GetValue(ctx);
        if (!val) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }
        ctx->SetDebugPos(pos);
        const Type *exprType = expr->GetType();
        llvm::Value *ptr = ctx->AllocaInst(exprType, "struct_tmp");
        ctx->StoreInst(val, ptr, exprType, exprType->IsUniformType());

        int elementNumber = getElementNumber();
        if (elementNumber == -1)
            return NULL;

        lvalue = ctx->AddElementOffset(ptr, elementNumber, PointerType::GetUniform(exprType));
        lvalueType = PointerType::GetUniform(GetType());
        // The temporary was just stored by this function, so the load uses
        // an all-on mask.
        mask = LLVMMaskAllOn;
    } else {
        Symbol *baseSym = GetBaseSymbol();
        AssertPos(pos, baseSym != NULL);
        // The load mask is chosen by lMaskForSymbol() from the base symbol.
        mask = lMaskForSymbol(baseSym, ctx);
    }

    ctx->SetDebugPos(pos);
    // The loaded value is named after the lvalue plus "_<identifier>".
    std::string suffix = std::string("_") + identifier;
    return ctx->LoadInst(lvalue, mask, lvalueType, llvm::Twine(lvalue->getName()) + suffix);
}
5020
GetType() const5021 const Type *MemberExpr::GetType() const { return NULL; }
5022
GetBaseSymbol() const5023 Symbol *MemberExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
5024
getElementNumber() const5025 int MemberExpr::getElementNumber() const { return -1; }
5026
GetLValue(FunctionEmitContext * ctx) const5027 llvm::Value *MemberExpr::GetLValue(FunctionEmitContext *ctx) const {
5028 const Type *exprType;
5029 if (!expr || ((exprType = expr->GetType()) == NULL))
5030 return NULL;
5031
5032 ctx->SetDebugPos(pos);
5033 llvm::Value *basePtr = dereferenceExpr ? expr->GetValue(ctx) : expr->GetLValue(ctx);
5034 if (!basePtr)
5035 return NULL;
5036
5037 int elementNumber = getElementNumber();
5038 if (elementNumber == -1)
5039 return NULL;
5040
5041 const Type *exprLValueType = dereferenceExpr ? exprType : expr->GetLValueType();
5042 ctx->SetDebugPos(pos);
5043 llvm::Value *ptr = ctx->AddElementOffset(basePtr, elementNumber, exprLValueType, basePtr->getName().str().c_str());
5044 if (ptr == NULL) {
5045 AssertPos(pos, m->errorCount > 0);
5046 return NULL;
5047 }
5048
5049 ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetLValueType());
5050
5051 return ptr;
5052 }
5053
TypeCheck()5054 Expr *MemberExpr::TypeCheck() { return expr ? this : NULL; }
5055
Optimize()5056 Expr *MemberExpr::Optimize() { return expr ? this : NULL; }
5057
EstimateCost() const5058 int MemberExpr::EstimateCost() const {
5059 const Type *lvalueType = GetLValueType();
5060 if (lvalueType != NULL && lvalueType->IsVaryingType())
5061 return COST_GATHER + COST_SIMPLE_ARITH_LOGIC_OP;
5062 else
5063 return COST_SIMPLE_ARITH_LOGIC_OP;
5064 }
5065
Print() const5066 void MemberExpr::Print() const {
5067 if (!expr || !GetType())
5068 return;
5069
5070 printf("[%s] member (", GetType()->GetString().c_str());
5071 expr->Print();
5072 printf(" . %s)", identifier.c_str());
5073 pos.Print();
5074 }
5075
5076 /** There is no structure member with the name we've got in "identifier".
5077 Use the approximate string matching routine to see if the identifier is
5078 a minor misspelling of one of the ones that is there.
5079 */
getCandidateNearMatches() const5080 std::string MemberExpr::getCandidateNearMatches() const {
5081 const StructType *structType = CastType<StructType>(expr->GetType());
5082 if (!structType)
5083 return "";
5084
5085 std::vector<std::string> elementNames;
5086 for (int i = 0; i < structType->GetElementCount(); ++i)
5087 elementNames.push_back(structType->GetElementName(i));
5088 std::vector<std::string> alternates = MatchStrings(identifier, elementNames);
5089 if (!alternates.size())
5090 return "";
5091
5092 std::string ret = " Did you mean ";
5093 for (unsigned int i = 0; i < alternates.size(); ++i) {
5094 ret += std::string("\"") + alternates[i] + std::string("\"");
5095 if (i < alternates.size() - 1)
5096 ret += ", or ";
5097 }
5098 ret += "?";
5099 return ret;
5100 }
5101
5102 ///////////////////////////////////////////////////////////////////////////
5103 // ConstExpr
5104
ConstExpr(const Type * t,int8_t i,SourcePos p)5105 ConstExpr::ConstExpr(const Type *t, int8_t i, SourcePos p) : Expr(p, ConstExprID) {
5106 type = t;
5107 type = type->GetAsConstType();
5108 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt8->GetAsConstType()));
5109 int8Val[0] = i;
5110 }
5111
ConstExpr(const Type * t,int8_t * i,SourcePos p)5112 ConstExpr::ConstExpr(const Type *t, int8_t *i, SourcePos p) : Expr(p, ConstExprID) {
5113 type = t;
5114 type = type->GetAsConstType();
5115 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt8->GetAsConstType()) ||
5116 Type::Equal(type, AtomicType::VaryingInt8->GetAsConstType()));
5117 for (int j = 0; j < Count(); ++j)
5118 int8Val[j] = i[j];
5119 }
5120
ConstExpr(const Type * t,uint8_t u,SourcePos p)5121 ConstExpr::ConstExpr(const Type *t, uint8_t u, SourcePos p) : Expr(p, ConstExprID) {
5122 type = t;
5123 type = type->GetAsConstType();
5124 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt8->GetAsConstType()));
5125 uint8Val[0] = u;
5126 }
5127
ConstExpr(const Type * t,uint8_t * u,SourcePos p)5128 ConstExpr::ConstExpr(const Type *t, uint8_t *u, SourcePos p) : Expr(p, ConstExprID) {
5129 type = t;
5130 type = type->GetAsConstType();
5131 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt8->GetAsConstType()) ||
5132 Type::Equal(type, AtomicType::VaryingUInt8->GetAsConstType()));
5133 for (int j = 0; j < Count(); ++j)
5134 uint8Val[j] = u[j];
5135 }
5136
ConstExpr(const Type * t,int16_t i,SourcePos p)5137 ConstExpr::ConstExpr(const Type *t, int16_t i, SourcePos p) : Expr(p, ConstExprID) {
5138 type = t;
5139 type = type->GetAsConstType();
5140 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt16->GetAsConstType()));
5141 int16Val[0] = i;
5142 }
5143
ConstExpr(const Type * t,int16_t * i,SourcePos p)5144 ConstExpr::ConstExpr(const Type *t, int16_t *i, SourcePos p) : Expr(p, ConstExprID) {
5145 type = t;
5146 type = type->GetAsConstType();
5147 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt16->GetAsConstType()) ||
5148 Type::Equal(type, AtomicType::VaryingInt16->GetAsConstType()));
5149 for (int j = 0; j < Count(); ++j)
5150 int16Val[j] = i[j];
5151 }
5152
ConstExpr(const Type * t,uint16_t u,SourcePos p)5153 ConstExpr::ConstExpr(const Type *t, uint16_t u, SourcePos p) : Expr(p, ConstExprID) {
5154 type = t;
5155 type = type->GetAsConstType();
5156 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt16->GetAsConstType()));
5157 uint16Val[0] = u;
5158 }
5159
ConstExpr(const Type * t,uint16_t * u,SourcePos p)5160 ConstExpr::ConstExpr(const Type *t, uint16_t *u, SourcePos p) : Expr(p, ConstExprID) {
5161 type = t;
5162 type = type->GetAsConstType();
5163 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt16->GetAsConstType()) ||
5164 Type::Equal(type, AtomicType::VaryingUInt16->GetAsConstType()));
5165 for (int j = 0; j < Count(); ++j)
5166 uint16Val[j] = u[j];
5167 }
5168
ConstExpr(const Type * t,int32_t i,SourcePos p)5169 ConstExpr::ConstExpr(const Type *t, int32_t i, SourcePos p) : Expr(p, ConstExprID) {
5170 type = t;
5171 type = type->GetAsConstType();
5172 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt32->GetAsConstType()));
5173 int32Val[0] = i;
5174 }
5175
ConstExpr(const Type * t,int32_t * i,SourcePos p)5176 ConstExpr::ConstExpr(const Type *t, int32_t *i, SourcePos p) : Expr(p, ConstExprID) {
5177 type = t;
5178 type = type->GetAsConstType();
5179 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt32->GetAsConstType()) ||
5180 Type::Equal(type, AtomicType::VaryingInt32->GetAsConstType()));
5181 for (int j = 0; j < Count(); ++j)
5182 int32Val[j] = i[j];
5183 }
5184
ConstExpr(const Type * t,uint32_t u,SourcePos p)5185 ConstExpr::ConstExpr(const Type *t, uint32_t u, SourcePos p) : Expr(p, ConstExprID) {
5186 type = t;
5187 type = type->GetAsConstType();
5188 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt32->GetAsConstType()) ||
5189 (CastType<EnumType>(type) != NULL && type->IsUniformType()));
5190 uint32Val[0] = u;
5191 }
5192
ConstExpr(const Type * t,uint32_t * u,SourcePos p)5193 ConstExpr::ConstExpr(const Type *t, uint32_t *u, SourcePos p) : Expr(p, ConstExprID) {
5194 type = t;
5195 type = type->GetAsConstType();
5196 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt32->GetAsConstType()) ||
5197 Type::Equal(type, AtomicType::VaryingUInt32->GetAsConstType()) ||
5198 (CastType<EnumType>(type) != NULL));
5199 for (int j = 0; j < Count(); ++j)
5200 uint32Val[j] = u[j];
5201 }
5202
ConstExpr(const Type * t,float f,SourcePos p)5203 ConstExpr::ConstExpr(const Type *t, float f, SourcePos p) : Expr(p, ConstExprID) {
5204 type = t;
5205 type = type->GetAsConstType();
5206 AssertPos(pos, Type::Equal(type, AtomicType::UniformFloat->GetAsConstType()));
5207 floatVal[0] = f;
5208 }
5209
ConstExpr(const Type * t,float * f,SourcePos p)5210 ConstExpr::ConstExpr(const Type *t, float *f, SourcePos p) : Expr(p, ConstExprID) {
5211 type = t;
5212 type = type->GetAsConstType();
5213 AssertPos(pos, Type::Equal(type, AtomicType::UniformFloat->GetAsConstType()) ||
5214 Type::Equal(type, AtomicType::VaryingFloat->GetAsConstType()));
5215 for (int j = 0; j < Count(); ++j)
5216 floatVal[j] = f[j];
5217 }
5218
ConstExpr(const Type * t,int64_t i,SourcePos p)5219 ConstExpr::ConstExpr(const Type *t, int64_t i, SourcePos p) : Expr(p, ConstExprID) {
5220 type = t;
5221 type = type->GetAsConstType();
5222 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt64->GetAsConstType()));
5223 int64Val[0] = i;
5224 }
5225
ConstExpr(const Type * t,int64_t * i,SourcePos p)5226 ConstExpr::ConstExpr(const Type *t, int64_t *i, SourcePos p) : Expr(p, ConstExprID) {
5227 type = t;
5228 type = type->GetAsConstType();
5229 AssertPos(pos, Type::Equal(type, AtomicType::UniformInt64->GetAsConstType()) ||
5230 Type::Equal(type, AtomicType::VaryingInt64->GetAsConstType()));
5231 for (int j = 0; j < Count(); ++j)
5232 int64Val[j] = i[j];
5233 }
5234
ConstExpr(const Type * t,uint64_t u,SourcePos p)5235 ConstExpr::ConstExpr(const Type *t, uint64_t u, SourcePos p) : Expr(p, ConstExprID) {
5236 type = t;
5237 type = type->GetAsConstType();
5238 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt64->GetAsConstType()));
5239 uint64Val[0] = u;
5240 }
5241
ConstExpr(const Type * t,uint64_t * u,SourcePos p)5242 ConstExpr::ConstExpr(const Type *t, uint64_t *u, SourcePos p) : Expr(p, ConstExprID) {
5243 type = t;
5244 type = type->GetAsConstType();
5245 AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt64->GetAsConstType()) ||
5246 Type::Equal(type, AtomicType::VaryingUInt64->GetAsConstType()));
5247 for (int j = 0; j < Count(); ++j)
5248 uint64Val[j] = u[j];
5249 }
5250
ConstExpr(const Type * t,double f,SourcePos p)5251 ConstExpr::ConstExpr(const Type *t, double f, SourcePos p) : Expr(p, ConstExprID) {
5252 type = t;
5253 type = type->GetAsConstType();
5254 AssertPos(pos, Type::Equal(type, AtomicType::UniformDouble->GetAsConstType()));
5255 doubleVal[0] = f;
5256 }
5257
ConstExpr(const Type * t,double * f,SourcePos p)5258 ConstExpr::ConstExpr(const Type *t, double *f, SourcePos p) : Expr(p, ConstExprID) {
5259 type = t;
5260 type = type->GetAsConstType();
5261 AssertPos(pos, Type::Equal(type, AtomicType::UniformDouble->GetAsConstType()) ||
5262 Type::Equal(type, AtomicType::VaryingDouble->GetAsConstType()));
5263 for (int j = 0; j < Count(); ++j)
5264 doubleVal[j] = f[j];
5265 }
5266
ConstExpr(const Type * t,bool b,SourcePos p)5267 ConstExpr::ConstExpr(const Type *t, bool b, SourcePos p) : Expr(p, ConstExprID) {
5268 type = t;
5269 type = type->GetAsConstType();
5270 AssertPos(pos, Type::Equal(type, AtomicType::UniformBool->GetAsConstType()));
5271 boolVal[0] = b;
5272 }
5273
ConstExpr(const Type * t,bool * b,SourcePos p)5274 ConstExpr::ConstExpr(const Type *t, bool *b, SourcePos p) : Expr(p, ConstExprID) {
5275 type = t;
5276 type = type->GetAsConstType();
5277 AssertPos(pos, Type::Equal(type, AtomicType::UniformBool->GetAsConstType()) ||
5278 Type::Equal(type, AtomicType::VaryingBool->GetAsConstType()));
5279 for (int j = 0; j < Count(); ++j)
5280 boolVal[j] = b[j];
5281 }
5282
ConstExpr(ConstExpr * old,double * v)5283 ConstExpr::ConstExpr(ConstExpr *old, double *v) : Expr(old->pos, ConstExprID) {
5284 type = old->type;
5285
5286 AtomicType::BasicType basicType = getBasicType();
5287
5288 switch (basicType) {
5289 case AtomicType::TYPE_BOOL:
5290 for (int i = 0; i < Count(); ++i)
5291 boolVal[i] = (v[i] != 0.);
5292 break;
5293 case AtomicType::TYPE_INT8:
5294 for (int i = 0; i < Count(); ++i)
5295 int8Val[i] = (int)v[i];
5296 break;
5297 case AtomicType::TYPE_UINT8:
5298 for (int i = 0; i < Count(); ++i)
5299 uint8Val[i] = (unsigned int)v[i];
5300 break;
5301 case AtomicType::TYPE_INT16:
5302 for (int i = 0; i < Count(); ++i)
5303 int16Val[i] = (int)v[i];
5304 break;
5305 case AtomicType::TYPE_UINT16:
5306 for (int i = 0; i < Count(); ++i)
5307 uint16Val[i] = (unsigned int)v[i];
5308 break;
5309 case AtomicType::TYPE_INT32:
5310 for (int i = 0; i < Count(); ++i)
5311 int32Val[i] = (int)v[i];
5312 break;
5313 case AtomicType::TYPE_UINT32:
5314 for (int i = 0; i < Count(); ++i)
5315 uint32Val[i] = (unsigned int)v[i];
5316 break;
5317 case AtomicType::TYPE_FLOAT:
5318 for (int i = 0; i < Count(); ++i)
5319 floatVal[i] = (float)v[i];
5320 break;
5321 case AtomicType::TYPE_DOUBLE:
5322 for (int i = 0; i < Count(); ++i)
5323 doubleVal[i] = v[i];
5324 break;
5325 case AtomicType::TYPE_INT64:
5326 case AtomicType::TYPE_UINT64:
5327 // For now, this should never be reached
5328 FATAL("fixme; we need another constructor so that we're not trying to pass "
5329 "double values to init an int64 type...");
5330 default:
5331 FATAL("unimplemented const type");
5332 }
5333 }
5334
ConstExpr(ConstExpr * old,SourcePos p)5335 ConstExpr::ConstExpr(ConstExpr *old, SourcePos p) : Expr(p, ConstExprID) {
5336 type = old->type;
5337
5338 AtomicType::BasicType basicType = getBasicType();
5339
5340 switch (basicType) {
5341 case AtomicType::TYPE_BOOL:
5342 memcpy(boolVal, old->boolVal, Count() * sizeof(bool));
5343 break;
5344 case AtomicType::TYPE_INT8:
5345 memcpy(int8Val, old->int8Val, Count() * sizeof(int8_t));
5346 break;
5347 case AtomicType::TYPE_UINT8:
5348 memcpy(uint8Val, old->uint8Val, Count() * sizeof(uint8_t));
5349 break;
5350 case AtomicType::TYPE_INT16:
5351 memcpy(int16Val, old->int16Val, Count() * sizeof(int16_t));
5352 break;
5353 case AtomicType::TYPE_UINT16:
5354 memcpy(uint16Val, old->uint16Val, Count() * sizeof(uint16_t));
5355 break;
5356 case AtomicType::TYPE_INT32:
5357 memcpy(int32Val, old->int32Val, Count() * sizeof(int32_t));
5358 break;
5359 case AtomicType::TYPE_UINT32:
5360 memcpy(uint32Val, old->uint32Val, Count() * sizeof(uint32_t));
5361 break;
5362 case AtomicType::TYPE_FLOAT:
5363 memcpy(floatVal, old->floatVal, Count() * sizeof(float));
5364 break;
5365 case AtomicType::TYPE_DOUBLE:
5366 memcpy(doubleVal, old->doubleVal, Count() * sizeof(double));
5367 break;
5368 case AtomicType::TYPE_INT64:
5369 memcpy(int64Val, old->int64Val, Count() * sizeof(int64_t));
5370 break;
5371 case AtomicType::TYPE_UINT64:
5372 memcpy(uint64Val, old->uint64Val, Count() * sizeof(uint64_t));
5373 break;
5374 default:
5375 FATAL("unimplemented const type");
5376 }
5377 }
5378
getBasicType() const5379 AtomicType::BasicType ConstExpr::getBasicType() const {
5380 const AtomicType *at = CastType<AtomicType>(type);
5381 if (at != NULL)
5382 return at->basicType;
5383 else {
5384 AssertPos(pos, CastType<EnumType>(type) != NULL);
5385 return AtomicType::TYPE_UINT32;
5386 }
5387 }
5388
GetType() const5389 const Type *ConstExpr::GetType() const { return type; }
5390
GetValue(FunctionEmitContext * ctx) const5391 llvm::Value *ConstExpr::GetValue(FunctionEmitContext *ctx) const {
5392 ctx->SetDebugPos(pos);
5393 bool isVarying = type->IsVaryingType();
5394
5395 AtomicType::BasicType basicType = getBasicType();
5396
5397 switch (basicType) {
5398 case AtomicType::TYPE_BOOL:
5399 if (isVarying)
5400 return LLVMBoolVector(boolVal);
5401 else
5402 return boolVal[0] ? LLVMTrue : LLVMFalse;
5403 case AtomicType::TYPE_INT8:
5404 return isVarying ? LLVMInt8Vector(int8Val) : LLVMInt8(int8Val[0]);
5405 case AtomicType::TYPE_UINT8:
5406 return isVarying ? LLVMUInt8Vector(uint8Val) : LLVMUInt8(uint8Val[0]);
5407 case AtomicType::TYPE_INT16:
5408 return isVarying ? LLVMInt16Vector(int16Val) : LLVMInt16(int16Val[0]);
5409 case AtomicType::TYPE_UINT16:
5410 return isVarying ? LLVMUInt16Vector(uint16Val) : LLVMUInt16(uint16Val[0]);
5411 case AtomicType::TYPE_INT32:
5412 return isVarying ? LLVMInt32Vector(int32Val) : LLVMInt32(int32Val[0]);
5413 case AtomicType::TYPE_UINT32:
5414 return isVarying ? LLVMUInt32Vector(uint32Val) : LLVMUInt32(uint32Val[0]);
5415 case AtomicType::TYPE_FLOAT:
5416 return isVarying ? LLVMFloatVector(floatVal) : LLVMFloat(floatVal[0]);
5417 case AtomicType::TYPE_INT64:
5418 return isVarying ? LLVMInt64Vector(int64Val) : LLVMInt64(int64Val[0]);
5419 case AtomicType::TYPE_UINT64:
5420 return isVarying ? LLVMUInt64Vector(uint64Val) : LLVMUInt64(uint64Val[0]);
5421 case AtomicType::TYPE_DOUBLE:
5422 return isVarying ? LLVMDoubleVector(doubleVal) : LLVMDouble(doubleVal[0]);
5423 default:
5424 FATAL("unimplemented const type");
5425 return NULL;
5426 }
5427 }
5428
5429 /* Type conversion templates: take advantage of C++ function overloading
5430 rules to get the one we want to match. */
5431
/* The fallback overload: when no more specific overload applies, rely on
   the ordinary C++ conversion from From to To. */
template <typename From, typename To> static inline void lConvertElement(From from, To *to) {
    const To converted = static_cast<To>(from);
    *to = converted;
}
5435
/** Converting a bool to a numeric type always produces exactly one (for
    true) or zero (for false).
 */
template <typename To> static inline void lConvertElement(bool from, To *to) {
    if (from)
        *to = static_cast<To>(1);
    else
        *to = static_cast<To>(0);
}
5440
/** Converting a numeric value to bool yields true exactly when the value
    compares unequal to zero.  (Do we actually need this one??) */
template <typename From> static inline void lConvertElement(From from, bool *to) {
    const bool isNonZero = (from != 0);
    *to = isNonZero;
}
5444
/** bool to bool needs no conversion: the value is copied unchanged. */
static inline void lConvertElement(bool from, bool *to) {
    to[0] = from;
}
5447
5448 /** Type conversion utility function
5449 */
lConvert(const From * from,To * to,int count,bool forceVarying)5450 template <typename From, typename To> static void lConvert(const From *from, To *to, int count, bool forceVarying) {
5451 for (int i = 0; i < count; ++i)
5452 lConvertElement(from[i], &to[i]);
5453
5454 if (forceVarying && count == 1)
5455 for (int i = 1; i < g->target->getVectorWidth(); ++i)
5456 to[i] = to[0];
5457 }
5458
GetValues(int64_t * ip,bool forceVarying) const5459 int ConstExpr::GetValues(int64_t *ip, bool forceVarying) const {
5460 switch (getBasicType()) {
5461 case AtomicType::TYPE_BOOL:
5462 lConvert(boolVal, ip, Count(), forceVarying);
5463 break;
5464 case AtomicType::TYPE_INT8:
5465 lConvert(int8Val, ip, Count(), forceVarying);
5466 break;
5467 case AtomicType::TYPE_UINT8:
5468 lConvert(uint8Val, ip, Count(), forceVarying);
5469 break;
5470 case AtomicType::TYPE_INT16:
5471 lConvert(int16Val, ip, Count(), forceVarying);
5472 break;
5473 case AtomicType::TYPE_UINT16:
5474 lConvert(uint16Val, ip, Count(), forceVarying);
5475 break;
5476 case AtomicType::TYPE_INT32:
5477 lConvert(int32Val, ip, Count(), forceVarying);
5478 break;
5479 case AtomicType::TYPE_UINT32:
5480 lConvert(uint32Val, ip, Count(), forceVarying);
5481 break;
5482 case AtomicType::TYPE_FLOAT:
5483 lConvert(floatVal, ip, Count(), forceVarying);
5484 break;
5485 case AtomicType::TYPE_DOUBLE:
5486 lConvert(doubleVal, ip, Count(), forceVarying);
5487 break;
5488 case AtomicType::TYPE_INT64:
5489 lConvert(int64Val, ip, Count(), forceVarying);
5490 break;
5491 case AtomicType::TYPE_UINT64:
5492 lConvert(uint64Val, ip, Count(), forceVarying);
5493 break;
5494 default:
5495 FATAL("unimplemented const type");
5496 }
5497 return Count();
5498 }
5499
GetValues(uint64_t * up,bool forceVarying) const5500 int ConstExpr::GetValues(uint64_t *up, bool forceVarying) const {
5501 switch (getBasicType()) {
5502 case AtomicType::TYPE_BOOL:
5503 lConvert(boolVal, up, Count(), forceVarying);
5504 break;
5505 case AtomicType::TYPE_INT8:
5506 lConvert(int8Val, up, Count(), forceVarying);
5507 break;
5508 case AtomicType::TYPE_UINT8:
5509 lConvert(uint8Val, up, Count(), forceVarying);
5510 break;
5511 case AtomicType::TYPE_INT16:
5512 lConvert(int16Val, up, Count(), forceVarying);
5513 break;
5514 case AtomicType::TYPE_UINT16:
5515 lConvert(uint16Val, up, Count(), forceVarying);
5516 break;
5517 case AtomicType::TYPE_INT32:
5518 lConvert(int32Val, up, Count(), forceVarying);
5519 break;
5520 case AtomicType::TYPE_UINT32:
5521 lConvert(uint32Val, up, Count(), forceVarying);
5522 break;
5523 case AtomicType::TYPE_FLOAT:
5524 lConvert(floatVal, up, Count(), forceVarying);
5525 break;
5526 case AtomicType::TYPE_DOUBLE:
5527 lConvert(doubleVal, up, Count(), forceVarying);
5528 break;
5529 case AtomicType::TYPE_INT64:
5530 lConvert(int64Val, up, Count(), forceVarying);
5531 break;
5532 case AtomicType::TYPE_UINT64:
5533 lConvert(uint64Val, up, Count(), forceVarying);
5534 break;
5535 default:
5536 FATAL("unimplemented const type");
5537 }
5538 return Count();
5539 }
5540
GetValues(double * d,bool forceVarying) const5541 int ConstExpr::GetValues(double *d, bool forceVarying) const {
5542 switch (getBasicType()) {
5543 case AtomicType::TYPE_BOOL:
5544 lConvert(boolVal, d, Count(), forceVarying);
5545 break;
5546 case AtomicType::TYPE_INT8:
5547 lConvert(int8Val, d, Count(), forceVarying);
5548 break;
5549 case AtomicType::TYPE_UINT8:
5550 lConvert(uint8Val, d, Count(), forceVarying);
5551 break;
5552 case AtomicType::TYPE_INT16:
5553 lConvert(int16Val, d, Count(), forceVarying);
5554 break;
5555 case AtomicType::TYPE_UINT16:
5556 lConvert(uint16Val, d, Count(), forceVarying);
5557 break;
5558 case AtomicType::TYPE_INT32:
5559 lConvert(int32Val, d, Count(), forceVarying);
5560 break;
5561 case AtomicType::TYPE_UINT32:
5562 lConvert(uint32Val, d, Count(), forceVarying);
5563 break;
5564 case AtomicType::TYPE_FLOAT:
5565 lConvert(floatVal, d, Count(), forceVarying);
5566 break;
5567 case AtomicType::TYPE_DOUBLE:
5568 lConvert(doubleVal, d, Count(), forceVarying);
5569 break;
5570 case AtomicType::TYPE_INT64:
5571 lConvert(int64Val, d, Count(), forceVarying);
5572 break;
5573 case AtomicType::TYPE_UINT64:
5574 lConvert(uint64Val, d, Count(), forceVarying);
5575 break;
5576 default:
5577 FATAL("unimplemented const type");
5578 }
5579 return Count();
5580 }
5581
GetValues(float * fp,bool forceVarying) const5582 int ConstExpr::GetValues(float *fp, bool forceVarying) const {
5583 switch (getBasicType()) {
5584 case AtomicType::TYPE_BOOL:
5585 lConvert(boolVal, fp, Count(), forceVarying);
5586 break;
5587 case AtomicType::TYPE_INT8:
5588 lConvert(int8Val, fp, Count(), forceVarying);
5589 break;
5590 case AtomicType::TYPE_UINT8:
5591 lConvert(uint8Val, fp, Count(), forceVarying);
5592 break;
5593 case AtomicType::TYPE_INT16:
5594 lConvert(int16Val, fp, Count(), forceVarying);
5595 break;
5596 case AtomicType::TYPE_UINT16:
5597 lConvert(uint16Val, fp, Count(), forceVarying);
5598 break;
5599 case AtomicType::TYPE_INT32:
5600 lConvert(int32Val, fp, Count(), forceVarying);
5601 break;
5602 case AtomicType::TYPE_UINT32:
5603 lConvert(uint32Val, fp, Count(), forceVarying);
5604 break;
5605 case AtomicType::TYPE_FLOAT:
5606 lConvert(floatVal, fp, Count(), forceVarying);
5607 break;
5608 case AtomicType::TYPE_DOUBLE:
5609 lConvert(doubleVal, fp, Count(), forceVarying);
5610 break;
5611 case AtomicType::TYPE_INT64:
5612 lConvert(int64Val, fp, Count(), forceVarying);
5613 break;
5614 case AtomicType::TYPE_UINT64:
5615 lConvert(uint64Val, fp, Count(), forceVarying);
5616 break;
5617 default:
5618 FATAL("unimplemented const type");
5619 }
5620 return Count();
5621 }
5622
GetValues(bool * b,bool forceVarying) const5623 int ConstExpr::GetValues(bool *b, bool forceVarying) const {
5624 switch (getBasicType()) {
5625 case AtomicType::TYPE_BOOL:
5626 lConvert(boolVal, b, Count(), forceVarying);
5627 break;
5628 case AtomicType::TYPE_INT8:
5629 lConvert(int8Val, b, Count(), forceVarying);
5630 break;
5631 case AtomicType::TYPE_UINT8:
5632 lConvert(uint8Val, b, Count(), forceVarying);
5633 break;
5634 case AtomicType::TYPE_INT16:
5635 lConvert(int16Val, b, Count(), forceVarying);
5636 break;
5637 case AtomicType::TYPE_UINT16:
5638 lConvert(uint16Val, b, Count(), forceVarying);
5639 break;
5640 case AtomicType::TYPE_INT32:
5641 lConvert(int32Val, b, Count(), forceVarying);
5642 break;
5643 case AtomicType::TYPE_UINT32:
5644 lConvert(uint32Val, b, Count(), forceVarying);
5645 break;
5646 case AtomicType::TYPE_FLOAT:
5647 lConvert(floatVal, b, Count(), forceVarying);
5648 break;
5649 case AtomicType::TYPE_DOUBLE:
5650 lConvert(doubleVal, b, Count(), forceVarying);
5651 break;
5652 case AtomicType::TYPE_INT64:
5653 lConvert(int64Val, b, Count(), forceVarying);
5654 break;
5655 case AtomicType::TYPE_UINT64:
5656 lConvert(uint64Val, b, Count(), forceVarying);
5657 break;
5658 default:
5659 FATAL("unimplemented const type");
5660 }
5661 return Count();
5662 }
5663
GetValues(int8_t * ip,bool forceVarying) const5664 int ConstExpr::GetValues(int8_t *ip, bool forceVarying) const {
5665 switch (getBasicType()) {
5666 case AtomicType::TYPE_BOOL:
5667 lConvert(boolVal, ip, Count(), forceVarying);
5668 break;
5669 case AtomicType::TYPE_INT8:
5670 lConvert(int8Val, ip, Count(), forceVarying);
5671 break;
5672 case AtomicType::TYPE_UINT8:
5673 lConvert(uint8Val, ip, Count(), forceVarying);
5674 break;
5675 case AtomicType::TYPE_INT16:
5676 lConvert(int16Val, ip, Count(), forceVarying);
5677 break;
5678 case AtomicType::TYPE_UINT16:
5679 lConvert(uint16Val, ip, Count(), forceVarying);
5680 break;
5681 case AtomicType::TYPE_INT32:
5682 lConvert(int32Val, ip, Count(), forceVarying);
5683 break;
5684 case AtomicType::TYPE_UINT32:
5685 lConvert(uint32Val, ip, Count(), forceVarying);
5686 break;
5687 case AtomicType::TYPE_FLOAT:
5688 lConvert(floatVal, ip, Count(), forceVarying);
5689 break;
5690 case AtomicType::TYPE_DOUBLE:
5691 lConvert(doubleVal, ip, Count(), forceVarying);
5692 break;
5693 case AtomicType::TYPE_INT64:
5694 lConvert(int64Val, ip, Count(), forceVarying);
5695 break;
5696 case AtomicType::TYPE_UINT64:
5697 lConvert(uint64Val, ip, Count(), forceVarying);
5698 break;
5699 default:
5700 FATAL("unimplemented const type");
5701 }
5702 return Count();
5703 }
5704
GetValues(uint8_t * up,bool forceVarying) const5705 int ConstExpr::GetValues(uint8_t *up, bool forceVarying) const {
5706 switch (getBasicType()) {
5707 case AtomicType::TYPE_BOOL:
5708 lConvert(boolVal, up, Count(), forceVarying);
5709 break;
5710 case AtomicType::TYPE_INT8:
5711 lConvert(int8Val, up, Count(), forceVarying);
5712 break;
5713 case AtomicType::TYPE_UINT8:
5714 lConvert(uint8Val, up, Count(), forceVarying);
5715 break;
5716 case AtomicType::TYPE_INT16:
5717 lConvert(int16Val, up, Count(), forceVarying);
5718 break;
5719 case AtomicType::TYPE_UINT16:
5720 lConvert(uint16Val, up, Count(), forceVarying);
5721 break;
5722 case AtomicType::TYPE_INT32:
5723 lConvert(int32Val, up, Count(), forceVarying);
5724 break;
5725 case AtomicType::TYPE_UINT32:
5726 lConvert(uint32Val, up, Count(), forceVarying);
5727 break;
5728 case AtomicType::TYPE_FLOAT:
5729 lConvert(floatVal, up, Count(), forceVarying);
5730 break;
5731 case AtomicType::TYPE_DOUBLE:
5732 lConvert(doubleVal, up, Count(), forceVarying);
5733 break;
5734 case AtomicType::TYPE_INT64:
5735 lConvert(int64Val, up, Count(), forceVarying);
5736 break;
5737 case AtomicType::TYPE_UINT64:
5738 lConvert(uint64Val, up, Count(), forceVarying);
5739 break;
5740 default:
5741 FATAL("unimplemented const type");
5742 }
5743 return Count();
5744 }
5745
GetValues(int16_t * ip,bool forceVarying) const5746 int ConstExpr::GetValues(int16_t *ip, bool forceVarying) const {
5747 switch (getBasicType()) {
5748 case AtomicType::TYPE_BOOL:
5749 lConvert(boolVal, ip, Count(), forceVarying);
5750 break;
5751 case AtomicType::TYPE_INT8:
5752 lConvert(int8Val, ip, Count(), forceVarying);
5753 break;
5754 case AtomicType::TYPE_UINT8:
5755 lConvert(uint8Val, ip, Count(), forceVarying);
5756 break;
5757 case AtomicType::TYPE_INT16:
5758 lConvert(int16Val, ip, Count(), forceVarying);
5759 break;
5760 case AtomicType::TYPE_UINT16:
5761 lConvert(uint16Val, ip, Count(), forceVarying);
5762 break;
5763 case AtomicType::TYPE_INT32:
5764 lConvert(int32Val, ip, Count(), forceVarying);
5765 break;
5766 case AtomicType::TYPE_UINT32:
5767 lConvert(uint32Val, ip, Count(), forceVarying);
5768 break;
5769 case AtomicType::TYPE_FLOAT:
5770 lConvert(floatVal, ip, Count(), forceVarying);
5771 break;
5772 case AtomicType::TYPE_DOUBLE:
5773 lConvert(doubleVal, ip, Count(), forceVarying);
5774 break;
5775 case AtomicType::TYPE_INT64:
5776 lConvert(int64Val, ip, Count(), forceVarying);
5777 break;
5778 case AtomicType::TYPE_UINT64:
5779 lConvert(uint64Val, ip, Count(), forceVarying);
5780 break;
5781 default:
5782 FATAL("unimplemented const type");
5783 }
5784 return Count();
5785 }
5786
GetValues(uint16_t * up,bool forceVarying) const5787 int ConstExpr::GetValues(uint16_t *up, bool forceVarying) const {
5788 switch (getBasicType()) {
5789 case AtomicType::TYPE_BOOL:
5790 lConvert(boolVal, up, Count(), forceVarying);
5791 break;
5792 case AtomicType::TYPE_INT8:
5793 lConvert(int8Val, up, Count(), forceVarying);
5794 break;
5795 case AtomicType::TYPE_UINT8:
5796 lConvert(uint8Val, up, Count(), forceVarying);
5797 break;
5798 case AtomicType::TYPE_INT16:
5799 lConvert(int16Val, up, Count(), forceVarying);
5800 break;
5801 case AtomicType::TYPE_UINT16:
5802 lConvert(uint16Val, up, Count(), forceVarying);
5803 break;
5804 case AtomicType::TYPE_INT32:
5805 lConvert(int32Val, up, Count(), forceVarying);
5806 break;
5807 case AtomicType::TYPE_UINT32:
5808 lConvert(uint32Val, up, Count(), forceVarying);
5809 break;
5810 case AtomicType::TYPE_FLOAT:
5811 lConvert(floatVal, up, Count(), forceVarying);
5812 break;
5813 case AtomicType::TYPE_DOUBLE:
5814 lConvert(doubleVal, up, Count(), forceVarying);
5815 break;
5816 case AtomicType::TYPE_INT64:
5817 lConvert(int64Val, up, Count(), forceVarying);
5818 break;
5819 case AtomicType::TYPE_UINT64:
5820 lConvert(uint64Val, up, Count(), forceVarying);
5821 break;
5822 default:
5823 FATAL("unimplemented const type");
5824 }
5825 return Count();
5826 }
5827
GetValues(int32_t * ip,bool forceVarying) const5828 int ConstExpr::GetValues(int32_t *ip, bool forceVarying) const {
5829 switch (getBasicType()) {
5830 case AtomicType::TYPE_BOOL:
5831 lConvert(boolVal, ip, Count(), forceVarying);
5832 break;
5833 case AtomicType::TYPE_INT8:
5834 lConvert(int8Val, ip, Count(), forceVarying);
5835 break;
5836 case AtomicType::TYPE_UINT8:
5837 lConvert(uint8Val, ip, Count(), forceVarying);
5838 break;
5839 case AtomicType::TYPE_INT16:
5840 lConvert(int16Val, ip, Count(), forceVarying);
5841 break;
5842 case AtomicType::TYPE_UINT16:
5843 lConvert(uint16Val, ip, Count(), forceVarying);
5844 break;
5845 case AtomicType::TYPE_INT32:
5846 lConvert(int32Val, ip, Count(), forceVarying);
5847 break;
5848 case AtomicType::TYPE_UINT32:
5849 lConvert(uint32Val, ip, Count(), forceVarying);
5850 break;
5851 case AtomicType::TYPE_FLOAT:
5852 lConvert(floatVal, ip, Count(), forceVarying);
5853 break;
5854 case AtomicType::TYPE_DOUBLE:
5855 lConvert(doubleVal, ip, Count(), forceVarying);
5856 break;
5857 case AtomicType::TYPE_INT64:
5858 lConvert(int64Val, ip, Count(), forceVarying);
5859 break;
5860 case AtomicType::TYPE_UINT64:
5861 lConvert(uint64Val, ip, Count(), forceVarying);
5862 break;
5863 default:
5864 FATAL("unimplemented const type");
5865 }
5866 return Count();
5867 }
5868
GetValues(uint32_t * up,bool forceVarying) const5869 int ConstExpr::GetValues(uint32_t *up, bool forceVarying) const {
5870 switch (getBasicType()) {
5871 case AtomicType::TYPE_BOOL:
5872 lConvert(boolVal, up, Count(), forceVarying);
5873 break;
5874 case AtomicType::TYPE_INT8:
5875 lConvert(int8Val, up, Count(), forceVarying);
5876 break;
5877 case AtomicType::TYPE_UINT8:
5878 lConvert(uint8Val, up, Count(), forceVarying);
5879 break;
5880 case AtomicType::TYPE_INT16:
5881 lConvert(int16Val, up, Count(), forceVarying);
5882 break;
5883 case AtomicType::TYPE_UINT16:
5884 lConvert(uint16Val, up, Count(), forceVarying);
5885 break;
5886 case AtomicType::TYPE_INT32:
5887 lConvert(int32Val, up, Count(), forceVarying);
5888 break;
5889 case AtomicType::TYPE_UINT32:
5890 lConvert(uint32Val, up, Count(), forceVarying);
5891 break;
5892 case AtomicType::TYPE_FLOAT:
5893 lConvert(floatVal, up, Count(), forceVarying);
5894 break;
5895 case AtomicType::TYPE_DOUBLE:
5896 lConvert(doubleVal, up, Count(), forceVarying);
5897 break;
5898 case AtomicType::TYPE_INT64:
5899 lConvert(int64Val, up, Count(), forceVarying);
5900 break;
5901 case AtomicType::TYPE_UINT64:
5902 lConvert(uint64Val, up, Count(), forceVarying);
5903 break;
5904 default:
5905 FATAL("unimplemented const type");
5906 }
5907 return Count();
5908 }
5909
Count() const5910 int ConstExpr::Count() const { return GetType()->IsVaryingType() ? g->target->getVectorWidth() : 1; }
5911
lGetConstExprConstant(const Type * constType,const ConstExpr * cExpr,bool isStorageType)5912 static std::pair<llvm::Constant *, bool> lGetConstExprConstant(const Type *constType, const ConstExpr *cExpr,
5913 bool isStorageType) {
5914 // Caller shouldn't be trying to stuff a varying value here into a
5915 // constant type.
5916 SourcePos pos = cExpr->pos;
5917 bool isNotValidForMultiTargetGlobal = false;
5918 if (constType->IsUniformType())
5919 AssertPos(pos, cExpr->Count() == 1);
5920
5921 constType = constType->GetAsNonConstType();
5922 if (Type::Equal(constType, AtomicType::UniformBool) || Type::Equal(constType, AtomicType::VaryingBool)) {
5923 bool bv[ISPC_MAX_NVEC];
5924 cExpr->GetValues(bv, constType->IsVaryingType());
5925 if (constType->IsUniformType()) {
5926 if (isStorageType)
5927 return std::pair<llvm::Constant *, bool>(bv[0] ? LLVMTrueInStorage : LLVMFalseInStorage,
5928 isNotValidForMultiTargetGlobal);
5929 else
5930 return std::pair<llvm::Constant *, bool>(bv[0] ? LLVMTrue : LLVMFalse, isNotValidForMultiTargetGlobal);
5931 } else {
5932 if (isStorageType)
5933 return std::pair<llvm::Constant *, bool>(LLVMBoolVectorInStorage(bv), isNotValidForMultiTargetGlobal);
5934 else
5935 return std::pair<llvm::Constant *, bool>(LLVMBoolVector(bv), isNotValidForMultiTargetGlobal);
5936 }
5937 } else if (Type::Equal(constType, AtomicType::UniformInt8) || Type::Equal(constType, AtomicType::VaryingInt8)) {
5938 int8_t iv[ISPC_MAX_NVEC];
5939 cExpr->GetValues(iv, constType->IsVaryingType());
5940 if (constType->IsUniformType())
5941 return std::pair<llvm::Constant *, bool>(LLVMInt8(iv[0]), isNotValidForMultiTargetGlobal);
5942 else
5943 return std::pair<llvm::Constant *, bool>(LLVMInt8Vector(iv), isNotValidForMultiTargetGlobal);
5944 } else if (Type::Equal(constType, AtomicType::UniformUInt8) || Type::Equal(constType, AtomicType::VaryingUInt8)) {
5945 uint8_t uiv[ISPC_MAX_NVEC];
5946 cExpr->GetValues(uiv, constType->IsVaryingType());
5947 if (constType->IsUniformType())
5948 return std::pair<llvm::Constant *, bool>(LLVMUInt8(uiv[0]), isNotValidForMultiTargetGlobal);
5949 else
5950 return std::pair<llvm::Constant *, bool>(LLVMUInt8Vector(uiv), isNotValidForMultiTargetGlobal);
5951 } else if (Type::Equal(constType, AtomicType::UniformInt16) || Type::Equal(constType, AtomicType::VaryingInt16)) {
5952 int16_t iv[ISPC_MAX_NVEC];
5953 cExpr->GetValues(iv, constType->IsVaryingType());
5954 if (constType->IsUniformType())
5955 return std::pair<llvm::Constant *, bool>(LLVMInt16(iv[0]), isNotValidForMultiTargetGlobal);
5956 else
5957 return std::pair<llvm::Constant *, bool>(LLVMInt16Vector(iv), isNotValidForMultiTargetGlobal);
5958 } else if (Type::Equal(constType, AtomicType::UniformUInt16) || Type::Equal(constType, AtomicType::VaryingUInt16)) {
5959 uint16_t uiv[ISPC_MAX_NVEC];
5960 cExpr->GetValues(uiv, constType->IsVaryingType());
5961 if (constType->IsUniformType())
5962 return std::pair<llvm::Constant *, bool>(LLVMUInt16(uiv[0]), isNotValidForMultiTargetGlobal);
5963 else
5964 return std::pair<llvm::Constant *, bool>(LLVMUInt16Vector(uiv), isNotValidForMultiTargetGlobal);
5965 } else if (Type::Equal(constType, AtomicType::UniformInt32) || Type::Equal(constType, AtomicType::VaryingInt32)) {
5966 int32_t iv[ISPC_MAX_NVEC];
5967 cExpr->GetValues(iv, constType->IsVaryingType());
5968 if (constType->IsUniformType())
5969 return std::pair<llvm::Constant *, bool>(LLVMInt32(iv[0]), isNotValidForMultiTargetGlobal);
5970 else
5971 return std::pair<llvm::Constant *, bool>(LLVMInt32Vector(iv), isNotValidForMultiTargetGlobal);
5972 } else if (Type::Equal(constType, AtomicType::UniformUInt32) || Type::Equal(constType, AtomicType::VaryingUInt32) ||
5973 CastType<EnumType>(constType) != NULL) {
5974 uint32_t uiv[ISPC_MAX_NVEC];
5975 cExpr->GetValues(uiv, constType->IsVaryingType());
5976 if (constType->IsUniformType())
5977 return std::pair<llvm::Constant *, bool>(LLVMUInt32(uiv[0]), isNotValidForMultiTargetGlobal);
5978 else
5979 return std::pair<llvm::Constant *, bool>(LLVMUInt32Vector(uiv), isNotValidForMultiTargetGlobal);
5980 } else if (Type::Equal(constType, AtomicType::UniformFloat) || Type::Equal(constType, AtomicType::VaryingFloat)) {
5981 float fv[ISPC_MAX_NVEC];
5982 cExpr->GetValues(fv, constType->IsVaryingType());
5983 if (constType->IsUniformType())
5984 return std::pair<llvm::Constant *, bool>(LLVMFloat(fv[0]), isNotValidForMultiTargetGlobal);
5985 else
5986 return std::pair<llvm::Constant *, bool>(LLVMFloatVector(fv), isNotValidForMultiTargetGlobal);
5987 } else if (Type::Equal(constType, AtomicType::UniformInt64) || Type::Equal(constType, AtomicType::VaryingInt64)) {
5988 int64_t iv[ISPC_MAX_NVEC];
5989 cExpr->GetValues(iv, constType->IsVaryingType());
5990 if (constType->IsUniformType())
5991 return std::pair<llvm::Constant *, bool>(LLVMInt64(iv[0]), isNotValidForMultiTargetGlobal);
5992 else
5993 return std::pair<llvm::Constant *, bool>(LLVMInt64Vector(iv), isNotValidForMultiTargetGlobal);
5994 } else if (Type::Equal(constType, AtomicType::UniformUInt64) || Type::Equal(constType, AtomicType::VaryingUInt64)) {
5995 uint64_t uiv[ISPC_MAX_NVEC];
5996 cExpr->GetValues(uiv, constType->IsVaryingType());
5997 if (constType->IsUniformType())
5998 return std::pair<llvm::Constant *, bool>(LLVMUInt64(uiv[0]), isNotValidForMultiTargetGlobal);
5999 else
6000 return std::pair<llvm::Constant *, bool>(LLVMUInt64Vector(uiv), isNotValidForMultiTargetGlobal);
6001 } else if (Type::Equal(constType, AtomicType::UniformDouble) || Type::Equal(constType, AtomicType::VaryingDouble)) {
6002 double dv[ISPC_MAX_NVEC];
6003 cExpr->GetValues(dv, constType->IsVaryingType());
6004 if (constType->IsUniformType())
6005 return std::pair<llvm::Constant *, bool>(LLVMDouble(dv[0]), isNotValidForMultiTargetGlobal);
6006 else
6007 return std::pair<llvm::Constant *, bool>(LLVMDoubleVector(dv), isNotValidForMultiTargetGlobal);
6008 } else if (CastType<PointerType>(constType) != NULL) {
6009 // The only time we should get here is if we have an integer '0'
6010 // constant that should be turned into a NULL pointer of the
6011 // appropriate type.
6012 llvm::Type *llvmType = constType->LLVMType(g->ctx);
6013 if (llvmType == NULL) {
6014 AssertPos(pos, m->errorCount > 0);
6015 return std::pair<llvm::Constant *, bool>(NULL, false);
6016 }
6017
6018 int64_t iv[ISPC_MAX_NVEC];
6019 cExpr->GetValues(iv, constType->IsVaryingType());
6020 for (int i = 0; i < cExpr->Count(); ++i)
6021 if (iv[i] != 0)
6022 // We'll issue an error about this later--trying to assign
6023 // a constant int to a pointer, without a typecast.
6024 return std::pair<llvm::Constant *, bool>(NULL, false);
6025
6026 return std::pair<llvm::Constant *, bool>(llvm::Constant::getNullValue(llvmType),
6027 isNotValidForMultiTargetGlobal);
6028 } else {
6029 Debug(pos, "Unable to handle type \"%s\" in ConstExpr::GetConstant().", constType->GetString().c_str());
6030 return std::pair<llvm::Constant *, bool>(NULL, isNotValidForMultiTargetGlobal);
6031 }
6032 }
6033
GetStorageConstant(const Type * constType) const6034 std::pair<llvm::Constant *, bool> ConstExpr::GetStorageConstant(const Type *constType) const {
6035 return lGetConstExprConstant(constType, this, true);
6036 }
GetConstant(const Type * constType) const6037 std::pair<llvm::Constant *, bool> ConstExpr::GetConstant(const Type *constType) const {
6038 return lGetConstExprConstant(constType, this, false);
6039 }
6040
Optimize()6041 Expr *ConstExpr::Optimize() { return this; }
6042
TypeCheck()6043 Expr *ConstExpr::TypeCheck() { return this; }
6044
EstimateCost() const6045 int ConstExpr::EstimateCost() const { return 0; }
6046
Print() const6047 void ConstExpr::Print() const {
6048 printf("[%s] (", GetType()->GetString().c_str());
6049 for (int i = 0; i < Count(); ++i) {
6050 switch (getBasicType()) {
6051 case AtomicType::TYPE_BOOL:
6052 printf("%s", boolVal[i] ? "true" : "false");
6053 break;
6054 case AtomicType::TYPE_INT8:
6055 printf("%d", (int)int8Val[i]);
6056 break;
6057 case AtomicType::TYPE_UINT8:
6058 printf("%u", (int)uint8Val[i]);
6059 break;
6060 case AtomicType::TYPE_INT16:
6061 printf("%d", (int)int16Val[i]);
6062 break;
6063 case AtomicType::TYPE_UINT16:
6064 printf("%u", (int)uint16Val[i]);
6065 break;
6066 case AtomicType::TYPE_INT32:
6067 printf("%d", int32Val[i]);
6068 break;
6069 case AtomicType::TYPE_UINT32:
6070 printf("%u", uint32Val[i]);
6071 break;
6072 case AtomicType::TYPE_FLOAT:
6073 printf("%f", floatVal[i]);
6074 break;
6075 case AtomicType::TYPE_INT64:
6076 printf("%" PRId64, int64Val[i]);
6077 break;
6078 case AtomicType::TYPE_UINT64:
6079 printf("%" PRIu64, uint64Val[i]);
6080 break;
6081 case AtomicType::TYPE_DOUBLE:
6082 printf("%f", doubleVal[i]);
6083 break;
6084 default:
6085 FATAL("unimplemented const type");
6086 }
6087 if (i != Count() - 1)
6088 printf(", ");
6089 }
6090 printf(")");
6091 pos.Print();
6092 }
6093
6094 ///////////////////////////////////////////////////////////////////////////
6095 // TypeCastExpr
6096
TypeCastExpr(const Type * t,Expr * e,SourcePos p)6097 TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p) : Expr(p, TypeCastExprID) {
6098 type = t;
6099 expr = e;
6100 }
6101
/** Handle all the grungy details of type conversion between atomic types.
    Given an input value in exprVal of type fromType, convert it to the
    llvm::Value with type toType.

    The function works in three steps:
      1. A name for the emitted instruction is built by appending
         "_to_<target type>" to the name of exprVal.
      2. An outer switch on the target basic type and an inner switch on
         the source basic type pick the cast to emit (or reuse exprVal when
         the LLVM representation is unchanged).  The LLVM target type is
         scalar or vector according to the variability of fromType, not
         toType.
      3. If toType is varying and fromType is uniform, the result is
         replicated across lanes with SmearUniform().

    Performance warnings are issued at pos for varying float/double to
    unsigned integer conversions, and for varying uint32 to float
    conversions when the target reports warnFtoU32IsExpensive().

    @param ctx       Emit context used to create the instructions.
    @param exprVal   Value to convert.
    @param toType    Destination atomic type.
    @param fromType  Atomic type of exprVal.
    @param pos       Source position used for performance warnings.
    @return          The converted value.  Unhandled basic types hit FATAL().
 */
static llvm::Value *lTypeConvAtomic(FunctionEmitContext *ctx, llvm::Value *exprVal, const AtomicType *toType,
                                    const AtomicType *fromType, SourcePos pos) {
    llvm::Value *cast = NULL;

    // Build the instruction name: "<source value name>_to_<target type>".
    std::string opName = exprVal->getName().str();
    switch (toType->basicType) {
    case AtomicType::TYPE_BOOL:
        opName += "_to_bool";
        break;
    case AtomicType::TYPE_INT8:
        opName += "_to_int8";
        break;
    case AtomicType::TYPE_UINT8:
        opName += "_to_uint8";
        break;
    case AtomicType::TYPE_INT16:
        opName += "_to_int16";
        break;
    case AtomicType::TYPE_UINT16:
        opName += "_to_uint16";
        break;
    case AtomicType::TYPE_INT32:
        opName += "_to_int32";
        break;
    case AtomicType::TYPE_UINT32:
        opName += "_to_uint32";
        break;
    case AtomicType::TYPE_INT64:
        opName += "_to_int64";
        break;
    case AtomicType::TYPE_UINT64:
        opName += "_to_uint64";
        break;
    case AtomicType::TYPE_FLOAT:
        opName += "_to_float";
        break;
    case AtomicType::TYPE_DOUBLE:
        opName += "_to_double";
        break;
    default:
        FATAL("Unimplemented");
    }
    const char *cOpName = opName.c_str();

    // Outer switch: destination type.  Inner switches: source type.
    switch (toType->basicType) {
    case AtomicType::TYPE_FLOAT: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::FloatType : LLVMTypes::FloatVectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                // If we have a bool vector of non-i1 elements, first
                // truncate down to a single bit.
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            // And then do an unsigned int->float cast
            cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_INT64:
            cast = ctx->CastInst(llvm::Instruction::SIToFP, // signed int to float
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
        case AtomicType::TYPE_UINT16:
        case AtomicType::TYPE_UINT32:
        case AtomicType::TYPE_UINT64:
            // float -> uint32 is the only conversion for which a signed cvt
            // exists which cannot be used for unsigned.
            // This is a problem for non-neon, non-avx512 targets from among
            // arm/x86 cpu targets. Revisit for genx/wasm.
            // NOTE(review): the text above speaks of float -> uint32, while
            // this branch emits an unsigned int -> float cast and only warns
            // for a uint32 source -- confirm the intended direction.
            if (fromType->IsVaryingType() && (g->target->warnFtoU32IsExpensive() == true) &&
                (fromType->basicType == AtomicType::TYPE_UINT32))
                PerformanceWarning(pos, "Conversion from unsigned int to float is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int to float
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
            // No-op cast.
            cast = exprVal;
            break;
        case AtomicType::TYPE_DOUBLE:
            cast = ctx->FPCastInst(exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_DOUBLE: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::DoubleType : LLVMTypes::DoubleVectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                // truncate bool vector values to i1s if necessary.
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int to double
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_INT64:
            cast = ctx->CastInst(llvm::Instruction::SIToFP, // signed int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
        case AtomicType::TYPE_UINT16:
        case AtomicType::TYPE_UINT32:
        case AtomicType::TYPE_UINT64:
            cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
            cast = ctx->FPCastInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_DOUBLE:
            // No-op cast.
            cast = exprVal;
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_INT8: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int8Type : LLVMTypes::Int8VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_UINT8:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_UINT16:
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32:
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            cast = ctx->TruncInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
        case AtomicType::TYPE_DOUBLE:
            cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_UINT8: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int8Type : LLVMTypes::Int8VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_UINT8:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_UINT16:
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32:
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            cast = ctx->TruncInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from float to unsigned int is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_DOUBLE:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from double to unsigned int is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_INT16: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int16Type : LLVMTypes::Int16VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
            cast = ctx->SExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_UINT16:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32:
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            cast = ctx->TruncInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
        case AtomicType::TYPE_DOUBLE:
            cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_UINT16: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int16Type : LLVMTypes::Int16VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
            cast = ctx->SExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_UINT16:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_FLOAT:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from float to unsigned int is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32:
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            cast = ctx->TruncInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_DOUBLE:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from double to unsigned int is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_INT32: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int32Type : LLVMTypes::Int32VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_INT16:
            cast = ctx->SExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
        case AtomicType::TYPE_UINT16:
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            cast = ctx->TruncInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
        case AtomicType::TYPE_DOUBLE:
            cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_UINT32: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int32Type : LLVMTypes::Int32VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_INT16:
            cast = ctx->SExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
        case AtomicType::TYPE_UINT16:
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_FLOAT:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from float to unsigned int is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            cast = ctx->TruncInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_DOUBLE:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from double to unsigned int is slow. "
                                        "Use \"int\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_INT64: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int64Type : LLVMTypes::Int64VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_INT32:
            cast = ctx->SExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
        case AtomicType::TYPE_UINT16:
        case AtomicType::TYPE_UINT32:
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_FLOAT:
        case AtomicType::TYPE_DOUBLE:
            cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_UINT64: {
        llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int64Type : LLVMTypes::Int64VectorType;
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType())
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_INT32:
            cast = ctx->SExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_UINT8:
        case AtomicType::TYPE_UINT16:
        case AtomicType::TYPE_UINT32:
            cast = ctx->ZExtInst(exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_FLOAT:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from float to unsigned int64 is slow. "
                                        "Use \"int64\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64:
            // Same width: the value is reused as is.
            cast = exprVal;
            break;
        case AtomicType::TYPE_DOUBLE:
            if (fromType->IsVaryingType())
                PerformanceWarning(pos, "Conversion from double to unsigned int64 is slow. "
                                        "Use \"int64\" if possible");
            cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
                                 exprVal, targetType, cOpName);
            break;
        default:
            FATAL("unimplemented");
        }
        break;
    }
    case AtomicType::TYPE_BOOL: {
        // Non-bool sources become bool by comparing against zero
        // (integer "not equal", floating point FCMP_ONE).
        switch (fromType->basicType) {
        case AtomicType::TYPE_BOOL:
            if (fromType->IsVaryingType()) {
                // truncate bool vector values to i1s if necessary.
                exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            }
            cast = exprVal;
            break;
        case AtomicType::TYPE_INT8:
        case AtomicType::TYPE_UINT8: {
            llvm::Value *zero =
                fromType->IsUniformType() ? (llvm::Value *)LLVMInt8(0) : (llvm::Value *)LLVMInt8Vector((int8_t)0);
            cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
            break;
        }
        case AtomicType::TYPE_INT16:
        case AtomicType::TYPE_UINT16: {
            llvm::Value *zero =
                fromType->IsUniformType() ? (llvm::Value *)LLVMInt16(0) : (llvm::Value *)LLVMInt16Vector((int16_t)0);
            cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
            break;
        }
        case AtomicType::TYPE_INT32:
        case AtomicType::TYPE_UINT32: {
            llvm::Value *zero =
                fromType->IsUniformType() ? (llvm::Value *)LLVMInt32(0) : (llvm::Value *)LLVMInt32Vector(0);
            cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
            break;
        }
        case AtomicType::TYPE_FLOAT: {
            llvm::Value *zero =
                fromType->IsUniformType() ? (llvm::Value *)LLVMFloat(0.f) : (llvm::Value *)LLVMFloatVector(0.f);
            cast = ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_ONE, exprVal, zero, cOpName);
            break;
        }
        case AtomicType::TYPE_INT64:
        case AtomicType::TYPE_UINT64: {
            llvm::Value *zero =
                fromType->IsUniformType() ? (llvm::Value *)LLVMInt64(0) : (llvm::Value *)LLVMInt64Vector((int64_t)0);
            cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
            break;
        }
        case AtomicType::TYPE_DOUBLE: {
            llvm::Value *zero =
                fromType->IsUniformType() ? (llvm::Value *)LLVMDouble(0.) : (llvm::Value *)LLVMDoubleVector(0.);
            cast = ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_ONE, exprVal, zero, cOpName);
            break;
        }
        default:
            FATAL("unimplemented");
        }

        if (fromType->IsUniformType()) {
            if (toType->IsVaryingType() && LLVMTypes::BoolVectorType != LLVMTypes::Int1VectorType) {
                // extend out to a bool as an i8/i16/i32 from the i1 here.
                // Then we'll turn that into a vector below, the way it
                // does for everyone else...
                Assert(cast);
                cast = ctx->SwitchBoolSize(cast, LLVMTypes::BoolVectorType->getElementType(),
                                           llvm::Twine(cast->getName()) + "to_i_bool");
            }
        } else {
            // fromType->IsVaryingType()
            cast = ctx->I1VecToBoolVec(cast);
        }
        break;
    }
    default:
        FATAL("unimplemented");
    }

    // If we also want to go from uniform to varying, replicate out the
    // value across the vector elements..
    if (toType->IsVaryingType() && fromType->IsUniformType())
        return ctx->SmearUniform(cast);
    else
        return cast;
}
6607
6608 // FIXME: fold this into the FunctionEmitContext::SmearUniform() method?
6609
6610 /** Converts the given value of the given type to be the varying
6611 equivalent, returning the resulting value.
6612 */
lUniformValueToVarying(FunctionEmitContext * ctx,llvm::Value * value,const Type * type,SourcePos pos)6613 static llvm::Value *lUniformValueToVarying(FunctionEmitContext *ctx, llvm::Value *value, const Type *type,
6614 SourcePos pos) {
6615 // nothing to do if it's already varying
6616 if (type->IsVaryingType())
6617 return value;
6618
6619 // for structs/arrays/vectors, just recursively make their elements
6620 // varying (if needed) and populate the return value.
6621 const CollectionType *collectionType = CastType<CollectionType>(type);
6622 if (collectionType != NULL) {
6623 llvm::Type *llvmType = type->GetAsVaryingType()->LLVMStorageType(g->ctx);
6624 llvm::Value *retValue = llvm::UndefValue::get(llvmType);
6625
6626 const StructType *structType = CastType<StructType>(type->GetAsVaryingType());
6627
6628 for (int i = 0; i < collectionType->GetElementCount(); ++i) {
6629 llvm::Value *v = ctx->ExtractInst(value, i, "get_element");
6630 // If struct has "bound uniform" member, we don't need to cast it to varying
6631 if (!(structType != NULL && structType->GetElementType(i)->IsUniformType())) {
6632 const Type *elemType = collectionType->GetElementType(i);
6633 // If member is a uniform bool, it needs to be truncated to i1 since
6634 // uniform bool in IR is i1 and i8 in struct
6635 // Consider switching to just a broadcast for bool
6636 if ((elemType->IsBoolType()) && (CastType<AtomicType>(elemType) != NULL)) {
6637 v = ctx->TruncInst(v, LLVMTypes::BoolType);
6638 }
6639 v = lUniformValueToVarying(ctx, v, elemType, pos);
6640 // If the extracted element if bool and varying needs to be
6641 // converted back to i8 vector to insert into varying struct.
6642 if ((elemType->IsBoolType()) && (CastType<AtomicType>(elemType) != NULL)) {
6643 v = ctx->SwitchBoolSize(v, LLVMTypes::BoolVectorStorageType);
6644 }
6645 }
6646 retValue = ctx->InsertInst(retValue, v, i, "set_element");
6647 }
6648 return retValue;
6649 }
6650
6651 // Otherwise we must have a uniform atomic or pointer type, so smear
6652 // its value across the vector lanes.
6653 if (CastType<AtomicType>(type)) {
6654 return lTypeConvAtomic(ctx, value, CastType<AtomicType>(type->GetAsVaryingType()), CastType<AtomicType>(type),
6655 pos);
6656 }
6657
6658 Assert(CastType<PointerType>(type) != NULL);
6659 return ctx->SmearUniform(value);
6660 }
6661
HasAmbiguousVariability(std::vector<const Expr * > & warn) const6662 bool TypeCastExpr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
6663
6664 if (expr == NULL)
6665 return false;
6666
6667 const Type *toType = type, *fromType = expr->GetType();
6668 if (toType == NULL || fromType == NULL)
6669 return false;
6670
6671 if (toType->HasUnboundVariability() && fromType->IsUniformType()) {
6672 warn.push_back(this);
6673 return true;
6674 }
6675
6676 return false;
6677 }
6678
PrintAmbiguousVariability() const6679 void TypeCastExpr::PrintAmbiguousVariability() const {
6680 Warning(pos,
6681 "Typecasting to type \"%s\" (variability not specified) "
6682 "from \"uniform\" type \"%s\" results in \"uniform\" variability.\n"
6683 "In the context of function argument it may lead to unexpected behavior. "
6684 "Casting to \"%s\" is recommended.",
6685 (type->GetString()).c_str(), ((expr->GetType())->GetString()).c_str(),
6686 (type->GetAsUniformType()->GetString()).c_str());
6687 }
6688
/** Emits the IR that converts the operand's value to the cast's target
    type and returns the converted value.

    The conversion is selected by the (source type, target type) pair, in
    this order: casts to void, casts from a pointer (to array, pointer,
    bool or integer), identical types modulo const, array-to-pointer,
    array-to-array, reference-to-reference, struct-to-struct,
    vector-to-vector, and finally conversions from an atomic/enum source
    (to short vector, pointer or atomic).

    @param ctx  Function emission context used to generate instructions.
    @return The converted value, or NULL when the operand is missing, when
            either type is unavailable, when the target type is void, or
            when one of the intermediate steps yields NULL. */
llvm::Value *TypeCastExpr::GetValue(FunctionEmitContext *ctx) const {
    if (!expr)
        return NULL;

    ctx->SetDebugPos(pos);
    const Type *toType = GetType(), *fromType = expr->GetType();
    if (toType == NULL || fromType == NULL) {
        // A missing type is only expected after an error has been reported.
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    if (toType->IsVoidType()) {
        // emit the code for the expression in case it has side-effects but
        // then we're done.
        (void)expr->GetValue(ctx);
        return NULL;
    }

    const PointerType *fromPointerType = CastType<PointerType>(fromType);
    const PointerType *toPointerType = CastType<PointerType>(toType);
    const ArrayType *toArrayType = CastType<ArrayType>(toType);
    const ArrayType *fromArrayType = CastType<ArrayType>(fromType);
    // --- Source is a pointer: every path in this branch returns. ---
    if (fromPointerType != NULL) {
        if (toArrayType != NULL) {
            // pointer -> array: the operand's value is passed through unchanged
            return expr->GetValue(ctx);
        } else if (toPointerType != NULL) {
            // pointer -> pointer
            llvm::Value *value = expr->GetValue(ctx);
            if (value == NULL)
                return NULL;

            if (fromPointerType->IsSlice() == false && toPointerType->IsSlice() == true) {
                // Convert from a non-slice pointer to a slice pointer by
                // creating a slice pointer structure with zero offsets.
                if (fromPointerType->IsUniformType())
                    value = ctx->MakeSlicePointer(value, LLVMInt32(0));
                else
                    value = ctx->MakeSlicePointer(value, LLVMInt32Vector(0));

                // FIXME: avoid error from unnecessary bitcast when all we
                // need to do is the slice conversion and don't need to
                // also do unif->varying conversions.  But this is really
                // ugly logic.
                if (value->getType() == toType->LLVMType(g->ctx))
                    return value;
            }

            if (fromType->IsUniformType() && toType->IsUniformType())
                // bitcast to the actual pointer type
                return ctx->BitCastInst(value, toType->LLVMType(g->ctx));
            else if (fromType->IsVaryingType() && toType->IsVaryingType()) {
                // both are vectors of ints already, nothing to do at the IR
                // level
                return value;
            } else {
                // Uniform -> varying pointer conversion
                AssertPos(pos, fromType->IsUniformType() && toType->IsVaryingType());
                if (fromPointerType->IsSlice()) {
                    // For slice pointers, we need to smear out both the
                    // pointer and the offset vector
                    AssertPos(pos, toPointerType->IsSlice());
                    llvm::Value *ptr = ctx->ExtractInst(value, 0);
                    llvm::Value *offset = ctx->ExtractInst(value, 1);
                    ptr = ctx->PtrToIntInst(ptr);
                    ptr = ctx->SmearUniform(ptr);
                    offset = ctx->SmearUniform(offset);
                    return ctx->MakeSlicePointer(ptr, offset);
                } else {
                    // Otherwise we just bitcast it to an int and smear it
                    // out to a vector
                    value = ctx->PtrToIntInst(value);
                    return ctx->SmearUniform(value);
                }
            }
        } else {
            // pointer -> atomic (bool or integer)
            AssertPos(pos, CastType<AtomicType>(toType) != NULL);
            if (toType->IsBoolType()) {
                // convert pointer to bool
                llvm::Type *lfu = fromType->GetAsUniformType()->LLVMType(g->ctx);
                llvm::PointerType *llvmFromUnifType = llvm::dyn_cast<llvm::PointerType>(lfu);

                // The result is "operand != null"; the null constant is
                // smeared when the source pointer is varying.
                llvm::Value *nullPtrValue = llvm::ConstantPointerNull::get(llvmFromUnifType);
                if (fromType->IsVaryingType())
                    nullPtrValue = ctx->SmearUniform(nullPtrValue);

                llvm::Value *exprVal = expr->GetValue(ctx);
                llvm::Value *cmp =
                    ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, nullPtrValue, "ptr_ne_NULL");

                if (toType->IsVaryingType()) {
                    // A uniform comparison result is smeared first, then the
                    // result is passed through I1VecToBoolVec().
                    if (fromType->IsUniformType())
                        cmp = ctx->SmearUniform(cmp);
                    cmp = ctx->I1VecToBoolVec(cmp);
                }

                return cmp;
            } else {
                // ptr -> int
                llvm::Value *value = expr->GetValue(ctx);
                if (value == NULL)
                    return NULL;

                if (toType->IsVaryingType() && fromType->IsUniformType())
                    value = ctx->SmearUniform(value);

                llvm::Type *llvmToType = toType->LLVMType(g->ctx);
                if (llvmToType == NULL)
                    return NULL;
                return ctx->PtrToIntInst(value, llvmToType, "ptr_typecast");
            }
        }
    }

    if (Type::EqualIgnoringConst(toType, fromType))
        // There's nothing to do, just return the value.  (LLVM's type
        // system doesn't worry about constiness.)
        return expr->GetValue(ctx);

    if (fromArrayType != NULL && toPointerType != NULL) {
        // implicit array to pointer to first element
        Expr *arrayAsPtr = lArrayToPointer(expr);
        if (Type::EqualIgnoringConst(arrayAsPtr->GetType(), toPointerType) == false) {
            // The pointer produced above doesn't match the target yet; the
            // only accepted mismatches are a void pointer target or a
            // target equal to the varying form of the produced pointer
            // type.  A second cast expression is built to bridge the gap.
            AssertPos(pos,
                      PointerType::IsVoidPointer(toPointerType) ||
                          Type::EqualIgnoringConst(arrayAsPtr->GetType()->GetAsVaryingType(), toPointerType) == true);
            arrayAsPtr = new TypeCastExpr(toPointerType, arrayAsPtr, pos);
            arrayAsPtr = ::TypeCheck(arrayAsPtr);
            AssertPos(pos, arrayAsPtr != NULL);
            arrayAsPtr = ::Optimize(arrayAsPtr);
            AssertPos(pos, arrayAsPtr != NULL);
        }
        AssertPos(pos, Type::EqualIgnoringConst(arrayAsPtr->GetType(), toPointerType));
        return arrayAsPtr->GetValue(ctx);
    }

    // This also should be caught during typechecking
    AssertPos(pos, !(toType->IsUniformType() && fromType->IsVaryingType()));

    if (toArrayType != NULL && fromArrayType != NULL) {
        // cast array pointer from [n x foo] to [0 x foo] if needed to be able
        // to pass to a function that takes an unsized array as a parameter
        if (toArrayType->GetElementCount() != 0 && (toArrayType->GetElementCount() != fromArrayType->GetElementCount()))
            Warning(pos, "Type-converting array of length %d to length %d", fromArrayType->GetElementCount(),
                    toArrayType->GetElementCount());
        AssertPos(pos, Type::EqualIgnoringConst(toArrayType->GetBaseType(), fromArrayType->GetBaseType()));
        llvm::Value *v = expr->GetValue(ctx);
        llvm::Type *ptype = toType->LLVMType(g->ctx);
        return ctx->BitCastInst(v, ptype); //, "array_cast_0size");
    }

    const ReferenceType *toReference = CastType<ReferenceType>(toType);
    const ReferenceType *fromReference = CastType<ReferenceType>(fromType);
    if (toReference && fromReference) {
        const Type *toTarget = toReference->GetReferenceTarget();
        const Type *fromTarget = fromReference->GetReferenceTarget();

        const ArrayType *toArray = CastType<ArrayType>(toTarget);
        const ArrayType *fromArray = CastType<ArrayType>(fromTarget);
        if (toArray && fromArray) {
            // cast array pointer from [n x foo] to [0 x foo] if needed to be able
            // to pass to a function that takes an unsized array as a parameter
            if (toArray->GetElementCount() != 0 && (toArray->GetElementCount() != fromArray->GetElementCount()))
                Warning(pos, "Type-converting array of length %d to length %d", fromArray->GetElementCount(),
                        toArray->GetElementCount());
            AssertPos(pos, Type::EqualIgnoringConst(toArray->GetBaseType(), fromArray->GetBaseType()));
            llvm::Value *v = expr->GetValue(ctx);
            llvm::Type *ptype = toType->LLVMType(g->ctx);
            return ctx->BitCastInst(v, ptype); //, "array_cast_0size");
        }

        // Just bitcast it.  See Issue #721
        llvm::Value *value = expr->GetValue(ctx);
        return ctx->BitCastInst(value, toType->LLVMType(g->ctx), "refcast");
    }

    const StructType *toStruct = CastType<StructType>(toType);
    const StructType *fromStruct = CastType<StructType>(fromType);
    if (toStruct && fromStruct) {
        // The only legal type conversions for structs are to go from a
        // uniform to a varying instance of the same struct type.
        AssertPos(pos, toStruct->IsVaryingType() && fromStruct->IsUniformType() &&
                           Type::EqualIgnoringConst(toStruct, fromStruct->GetAsVaryingType()));

        llvm::Value *origValue = expr->GetValue(ctx);
        if (!origValue)
            return NULL;
        return lUniformValueToVarying(ctx, origValue, fromType, pos);
    }

    const VectorType *toVector = CastType<VectorType>(toType);
    const VectorType *fromVector = CastType<VectorType>(fromType);
    if (toVector && fromVector) {
        // this should be caught during typechecking
        AssertPos(pos, toVector->GetElementCount() == fromVector->GetElementCount());

        llvm::Value *exprVal = expr->GetValue(ctx);
        if (!exprVal)
            return NULL;

        // Emit instructions to do type conversion of each of the elements
        // of the vector.
        // FIXME: since uniform vectors are represented as
        // llvm::VectorTypes, we should just be able to issue the
        // corresponding vector type convert, which should be more
        // efficient by avoiding serialization!
        llvm::Value *cast = llvm::UndefValue::get(toType->LLVMStorageType(g->ctx));
        for (int i = 0; i < toVector->GetElementCount(); ++i) {
            llvm::Value *ei = ctx->ExtractInst(exprVal, i);
            llvm::Value *conv = lTypeConvAtomic(ctx, ei, toVector->GetElementType(), fromVector->GetElementType(), pos);
            if (!conv)
                return NULL;
            // Atomic bool elements are switched to the element's storage
            // type before being inserted into the result.
            if ((toVector->GetElementType()->IsBoolType()) &&
                (CastType<AtomicType>(toVector->GetElementType()) != NULL)) {
                conv = ctx->SwitchBoolSize(conv, toVector->GetElementType()->LLVMStorageType(g->ctx));
            }

            cast = ctx->InsertInst(cast, conv, i);
        }
        return cast;
    }

    // From here on the operand's value is needed by every remaining path.
    llvm::Value *exprVal = expr->GetValue(ctx);
    if (!exprVal)
        return NULL;

    const EnumType *fromEnum = CastType<EnumType>(fromType);
    const EnumType *toEnum = CastType<EnumType>(toType);
    if (fromEnum)
        // treat it as an uint32 type for the below and all will be good.
        fromType = fromEnum->IsUniformType() ? AtomicType::UniformUInt32 : AtomicType::VaryingUInt32;
    if (toEnum)
        // treat it as an uint32 type for the below and all will be good.
        toType = toEnum->IsUniformType() ? AtomicType::UniformUInt32 : AtomicType::VaryingUInt32;

    const AtomicType *fromAtomic = CastType<AtomicType>(fromType);
    // at this point, coming from an atomic type is all that's left...
    AssertPos(pos, fromAtomic != NULL);

    if (toVector) {
        // scalar -> short vector conversion
        llvm::Value *conv = lTypeConvAtomic(ctx, exprVal, toVector->GetElementType(), fromAtomic, pos);
        if (!conv)
            return NULL;

        llvm::Value *cast = NULL;
        llvm::Type *toTypeLLVM = toType->LLVMStorageType(g->ctx);
        // The replication strategy depends on whether the target's storage
        // type is an LLVM vector or an LLVM array.
        if (llvm::isa<llvm::VectorType>(toTypeLLVM)) {
            // Example uniform float => uniform float<3>
            cast = ctx->BroadcastValue(conv, toTypeLLVM);
        } else if (llvm::isa<llvm::ArrayType>(toTypeLLVM)) {
            // Example varying float => varying float<3>
            cast = llvm::UndefValue::get(toType->LLVMStorageType(g->ctx));
            for (int i = 0; i < toVector->GetElementCount(); ++i) {
                if ((toVector->GetElementType()->IsBoolType()) &&
                    (CastType<AtomicType>(toVector->GetElementType()) != NULL)) {
                    conv = ctx->SwitchBoolSize(conv, toVector->GetElementType()->LLVMStorageType(g->ctx));
                }
                // Here's InsertInst produces InsertValueInst.
                cast = ctx->InsertInst(cast, conv, i);
            }
        } else {
            FATAL("TypeCastExpr::GetValue: problem with cast");
        }

        return cast;
    } else if (toPointerType != NULL) {
        // int -> ptr
        if (toType->IsVaryingType() && fromType->IsUniformType())
            exprVal = ctx->SmearUniform(exprVal);

        llvm::Type *llvmToType = toType->LLVMType(g->ctx);
        if (llvmToType == NULL)
            return NULL;

        return ctx->IntToPtrInst(exprVal, llvmToType, "int_to_ptr");
    } else {
        const AtomicType *toAtomic = CastType<AtomicType>(toType);
        // typechecking should ensure this is the case
        AssertPos(pos, toAtomic != NULL);

        return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos);
    }
}
6971
GetLValue(FunctionEmitContext * ctx) const6972 llvm::Value *TypeCastExpr::GetLValue(FunctionEmitContext *ctx) const {
6973 if (GetLValueType() != NULL) {
6974 return GetValue(ctx);
6975 } else {
6976 return NULL;
6977 }
6978 }
6979
GetType() const6980 const Type *TypeCastExpr::GetType() const {
6981 // Here we try to resolve situation where (base_type) can be treated as
6982 // (uniform base_type) of (varying base_type). This is a part of function
6983 // TypeCastExpr::TypeCheck. After implementation of operators we
6984 // have to have this functionality here.
6985 if (expr == NULL)
6986 return NULL;
6987 const Type *toType = type, *fromType = expr->GetType();
6988 if (toType == NULL || fromType == NULL)
6989 return NULL;
6990
6991 if (toType->HasUnboundVariability()) {
6992 if (fromType->IsUniformType()) {
6993 toType = type->ResolveUnboundVariability(Variability::Uniform);
6994 } else {
6995 toType = type->ResolveUnboundVariability(Variability::Varying);
6996 }
6997 }
6998 AssertPos(pos, toType->HasUnboundVariability() == false);
6999 return toType;
7000 }
7001
GetLValueType() const7002 const Type *TypeCastExpr::GetLValueType() const {
7003 AssertPos(pos, type->HasUnboundVariability() == false);
7004 if (CastType<PointerType>(GetType()) != NULL) {
7005 return type;
7006 } else {
7007 return NULL;
7008 }
7009 }
7010
/** Returns a non-const version of the given type.  For pointer types a
    new, non-const PointerType with the same variability is built around
    the recursively de-constified base type; any other type is handed to
    GetAsNonConstType(). */
static const Type *lDeconstifyType(const Type *t) {
    if (const PointerType *ptrType = CastType<PointerType>(t)) {
        return new PointerType(lDeconstifyType(ptrType->GetBaseType()), ptrType->GetVariability(), false);
    }
    return t->GetAsNonConstType();
}
7018
TypeCheck()7019 Expr *TypeCastExpr::TypeCheck() {
7020 if (expr == NULL)
7021 return NULL;
7022
7023 const Type *toType = type, *fromType = expr->GetType();
7024 if (toType == NULL || fromType == NULL)
7025 return NULL;
7026
7027 if (toType->HasUnboundVariability() && fromType->IsUniformType()) {
7028 TypeCastExpr *tce = new TypeCastExpr(toType->GetAsUniformType(), expr, pos);
7029 return ::TypeCheck(tce);
7030 }
7031 type = toType = type->ResolveUnboundVariability(Variability::Varying);
7032
7033 fromType = lDeconstifyType(fromType);
7034 toType = lDeconstifyType(toType);
7035
7036 // Anything can be cast to void...
7037 if (toType->IsVoidType())
7038 return this;
7039
7040 if (fromType->IsVoidType() || (fromType->IsVaryingType() && toType->IsUniformType())) {
7041 Error(pos, "Can't type cast from type \"%s\" to type \"%s\"", fromType->GetString().c_str(),
7042 toType->GetString().c_str());
7043 return NULL;
7044 }
7045
7046 // First some special cases that we allow only with an explicit type cast
7047 const PointerType *fromPtr = CastType<PointerType>(fromType);
7048 const PointerType *toPtr = CastType<PointerType>(toType);
7049 if (fromPtr != NULL && toPtr != NULL)
7050 // allow explicit typecasts between any two different pointer types
7051 return this;
7052
7053 const ReferenceType *fromRef = CastType<ReferenceType>(fromType);
7054 const ReferenceType *toRef = CastType<ReferenceType>(toType);
7055 if (fromRef != NULL && toRef != NULL) {
7056 // allow explicit typecasts between any two different reference types
7057 // Issues #721
7058 return this;
7059 }
7060
7061 const AtomicType *fromAtomic = CastType<AtomicType>(fromType);
7062 const AtomicType *toAtomic = CastType<AtomicType>(toType);
7063 const EnumType *fromEnum = CastType<EnumType>(fromType);
7064 const EnumType *toEnum = CastType<EnumType>(toType);
7065 if ((fromAtomic || fromEnum) && (toAtomic || toEnum))
7066 // Allow explicit casts between all of these
7067 return this;
7068
7069 // ptr -> int type casts
7070 if (fromPtr != NULL && toAtomic != NULL && toAtomic->IsIntType()) {
7071 bool safeCast =
7072 (toAtomic->basicType == AtomicType::TYPE_INT64 || toAtomic->basicType == AtomicType::TYPE_UINT64);
7073 if (g->target->is32Bit())
7074 safeCast |=
7075 (toAtomic->basicType == AtomicType::TYPE_INT32 || toAtomic->basicType == AtomicType::TYPE_UINT32);
7076 if (safeCast == false)
7077 Warning(pos,
7078 "Pointer type cast of type \"%s\" to integer type "
7079 "\"%s\" may lose information.",
7080 fromType->GetString().c_str(), toType->GetString().c_str());
7081 return this;
7082 }
7083
7084 // int -> ptr
7085 if (fromAtomic != NULL && fromAtomic->IsIntType() && toPtr != NULL)
7086 return this;
7087
7088 // And otherwise see if it's one of the conversions allowed to happen
7089 // implicitly.
7090 Expr *e = TypeConvertExpr(expr, toType, "type cast expression");
7091 if (e == NULL)
7092 return NULL;
7093 else
7094 return e;
7095 }
7096
Optimize()7097 Expr *TypeCastExpr::Optimize() {
7098 ConstExpr *constExpr = llvm::dyn_cast<ConstExpr>(expr);
7099 if (constExpr == NULL)
7100 // We can't do anything if this isn't a const expr
7101 return this;
7102
7103 const Type *toType = GetType();
7104 const AtomicType *toAtomic = CastType<AtomicType>(toType);
7105 const EnumType *toEnum = CastType<EnumType>(toType);
7106 // If we're not casting to an atomic or enum type, we can't do anything
7107 // here, since ConstExprs can only represent those two types. (So
7108 // e.g. we're casting from an int to an int<4>.)
7109 if (toAtomic == NULL && toEnum == NULL)
7110 return this;
7111
7112 bool forceVarying = toType->IsVaryingType();
7113
7114 // All of the type conversion smarts we need is already in the
7115 // ConstExpr::GetValues(), etc., methods, so we just need to call the
7116 // appropriate one for the type that this cast is converting to.
7117 AtomicType::BasicType basicType = toAtomic ? toAtomic->basicType : AtomicType::TYPE_UINT32;
7118 switch (basicType) {
7119 case AtomicType::TYPE_BOOL: {
7120 bool bv[ISPC_MAX_NVEC];
7121 constExpr->GetValues(bv, forceVarying);
7122 return new ConstExpr(toType, bv, pos);
7123 }
7124 case AtomicType::TYPE_INT8: {
7125 int8_t iv[ISPC_MAX_NVEC];
7126 constExpr->GetValues(iv, forceVarying);
7127 return new ConstExpr(toType, iv, pos);
7128 }
7129 case AtomicType::TYPE_UINT8: {
7130 uint8_t uv[ISPC_MAX_NVEC];
7131 constExpr->GetValues(uv, forceVarying);
7132 return new ConstExpr(toType, uv, pos);
7133 }
7134 case AtomicType::TYPE_INT16: {
7135 int16_t iv[ISPC_MAX_NVEC];
7136 constExpr->GetValues(iv, forceVarying);
7137 return new ConstExpr(toType, iv, pos);
7138 }
7139 case AtomicType::TYPE_UINT16: {
7140 uint16_t uv[ISPC_MAX_NVEC];
7141 constExpr->GetValues(uv, forceVarying);
7142 return new ConstExpr(toType, uv, pos);
7143 }
7144 case AtomicType::TYPE_INT32: {
7145 int32_t iv[ISPC_MAX_NVEC];
7146 constExpr->GetValues(iv, forceVarying);
7147 return new ConstExpr(toType, iv, pos);
7148 }
7149 case AtomicType::TYPE_UINT32: {
7150 uint32_t uv[ISPC_MAX_NVEC];
7151 constExpr->GetValues(uv, forceVarying);
7152 return new ConstExpr(toType, uv, pos);
7153 }
7154 case AtomicType::TYPE_FLOAT: {
7155 float fv[ISPC_MAX_NVEC];
7156 constExpr->GetValues(fv, forceVarying);
7157 return new ConstExpr(toType, fv, pos);
7158 }
7159 case AtomicType::TYPE_INT64: {
7160 int64_t iv[ISPC_MAX_NVEC];
7161 constExpr->GetValues(iv, forceVarying);
7162 return new ConstExpr(toType, iv, pos);
7163 }
7164 case AtomicType::TYPE_UINT64: {
7165 uint64_t uv[ISPC_MAX_NVEC];
7166 constExpr->GetValues(uv, forceVarying);
7167 return new ConstExpr(toType, uv, pos);
7168 }
7169 case AtomicType::TYPE_DOUBLE: {
7170 double dv[ISPC_MAX_NVEC];
7171 constExpr->GetValues(dv, forceVarying);
7172 return new ConstExpr(toType, dv, pos);
7173 }
7174 default:
7175 FATAL("unimplemented");
7176 }
7177 return this;
7178 }
7179
EstimateCost() const7180 int TypeCastExpr::EstimateCost() const {
7181 if (llvm::dyn_cast<ConstExpr>(expr) != NULL)
7182 return 0;
7183
7184 // FIXME: return COST_TYPECAST_COMPLEX when appropriate
7185 return COST_TYPECAST_SIMPLE;
7186 }
7187
Print() const7188 void TypeCastExpr::Print() const {
7189 printf("[%s] type cast (", GetType()->GetString().c_str());
7190 expr->Print();
7191 printf(")");
7192 pos.Print();
7193 }
7194
GetBaseSymbol() const7195 Symbol *TypeCastExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7196
lConvertPointerConstant(llvm::Constant * c,const Type * constType)7197 static llvm::Constant *lConvertPointerConstant(llvm::Constant *c, const Type *constType) {
7198 if (c == NULL || constType->IsUniformType())
7199 return c;
7200
7201 // Handle conversion to int and then to vector of int or array of int
7202 // (for varying and soa types, respectively)
7203 llvm::Constant *intPtr = llvm::ConstantExpr::getPtrToInt(c, LLVMTypes::PointerIntType);
7204 Assert(constType->IsVaryingType() || constType->IsSOAType());
7205 int count = constType->IsVaryingType() ? g->target->getVectorWidth() : constType->GetSOAWidth();
7206
7207 std::vector<llvm::Constant *> smear;
7208 for (int i = 0; i < count; ++i)
7209 smear.push_back(intPtr);
7210
7211 if (constType->IsVaryingType())
7212 return llvm::ConstantVector::get(smear);
7213 else {
7214 llvm::ArrayType *at = llvm::ArrayType::get(LLVMTypes::PointerIntType, count);
7215 return llvm::ConstantArray::get(at, smear);
7216 }
7217 }
7218
GetConstant(const Type * constType) const7219 std::pair<llvm::Constant *, bool> TypeCastExpr::GetConstant(const Type *constType) const {
7220 // We don't need to worry about most the basic cases where the type
7221 // cast can resolve to a constant here, since the
7222 // TypeCastExpr::Optimize() method generally ends up doing the type
7223 // conversion and returning a ConstExpr, which in turn will have its
7224 // GetConstant() method called. However, because ConstExpr currently
7225 // can't represent pointer values, we have to handle a few cases
7226 // related to pointers here:
7227 //
7228 // 1. Null pointer (NULL, 0) valued initializers
7229 // 2. Converting function types to pointer-to-function types
7230 // 3. And converting these from uniform to the varying/soa equivalents.
7231 //
7232
7233 if ((CastType<PointerType>(constType) == NULL) && (llvm::dyn_cast<SizeOfExpr>(expr) == NULL))
7234 return std::pair<llvm::Constant *, bool>(NULL, false);
7235
7236 llvm::Value *ptr = NULL;
7237 if (GetBaseSymbol())
7238 ptr = GetBaseSymbol()->storagePtr;
7239
7240 if (ptr && llvm::dyn_cast<llvm::GlobalVariable>(ptr)) {
7241 if (CastType<ArrayType>(expr->GetType())) {
7242 if (llvm::Constant *c = llvm::dyn_cast<llvm::Constant>(ptr)) {
7243 llvm::Value *offsets[2] = {LLVMInt32(0), LLVMInt32(0)};
7244 llvm::ArrayRef<llvm::Value *> arrayRef(&offsets[0], &offsets[2]);
7245 llvm::Value *resultPtr = llvm::ConstantExpr::getGetElementPtr(PTYPE(c), c, arrayRef);
7246 if (resultPtr->getType() == constType->LLVMType(g->ctx)) {
7247 llvm::Constant *ret = llvm::dyn_cast<llvm::Constant>(resultPtr);
7248 return std::pair<llvm::Constant *, bool>(ret, false);
7249 }
7250 }
7251 }
7252 }
7253
7254 std::pair<llvm::Constant *, bool> cPair = expr->GetConstant(constType->GetAsUniformType());
7255 llvm::Constant *c = cPair.first;
7256 return std::pair<llvm::Constant *, bool>(lConvertPointerConstant(c, constType), cPair.second);
7257 }
7258
7259 ///////////////////////////////////////////////////////////////////////////
7260 // ReferenceExpr
7261
ReferenceExpr(Expr * e,SourcePos p)7262 ReferenceExpr::ReferenceExpr(Expr *e, SourcePos p) : Expr(p, ReferenceExprID) { expr = e; }
7263
GetValue(FunctionEmitContext * ctx) const7264 llvm::Value *ReferenceExpr::GetValue(FunctionEmitContext *ctx) const {
7265 ctx->SetDebugPos(pos);
7266 if (expr == NULL) {
7267 AssertPos(pos, m->errorCount > 0);
7268 return NULL;
7269 }
7270
7271 llvm::Value *value = expr->GetLValue(ctx);
7272 if (value != NULL)
7273 return value;
7274
7275 // value is NULL if the expression is a temporary; in this case, we'll
7276 // allocate storage for it so that we can return the pointer to that...
7277 const Type *type;
7278 if ((type = expr->GetType()) == NULL || type->LLVMType(g->ctx) == NULL) {
7279 AssertPos(pos, m->errorCount > 0);
7280 return NULL;
7281 }
7282
7283 value = expr->GetValue(ctx);
7284 if (value == NULL) {
7285 AssertPos(pos, m->errorCount > 0);
7286 return NULL;
7287 }
7288
7289 llvm::Value *ptr = ctx->AllocaInst(type);
7290 ctx->StoreInst(value, ptr, type, expr->GetType()->IsUniformType());
7291 return ptr;
7292 }
7293
GetBaseSymbol() const7294 Symbol *ReferenceExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7295
GetType() const7296 const Type *ReferenceExpr::GetType() const {
7297 if (!expr)
7298 return NULL;
7299
7300 const Type *type = expr->GetType();
7301 if (!type)
7302 return NULL;
7303
7304 return new ReferenceType(type);
7305 }
7306
GetLValueType() const7307 const Type *ReferenceExpr::GetLValueType() const {
7308 if (!expr)
7309 return NULL;
7310
7311 const Type *type = expr->GetType();
7312 if (!type)
7313 return NULL;
7314
7315 return PointerType::GetUniform(type);
7316 }
7317
Optimize()7318 Expr *ReferenceExpr::Optimize() {
7319 if (expr == NULL)
7320 return NULL;
7321 return this;
7322 }
7323
TypeCheck()7324 Expr *ReferenceExpr::TypeCheck() {
7325 if (expr == NULL)
7326 return NULL;
7327 return this;
7328 }
7329
EstimateCost() const7330 int ReferenceExpr::EstimateCost() const { return 0; }
7331
Print() const7332 void ReferenceExpr::Print() const {
7333 if (expr == NULL || GetType() == NULL)
7334 return;
7335
7336 printf("[%s] &(", GetType()->GetString().c_str());
7337 expr->Print();
7338 printf(")");
7339 pos.Print();
7340 }
7341
7342 ///////////////////////////////////////////////////////////////////////////
7343 // DerefExpr
7344
DerefExpr(Expr * e,SourcePos p,unsigned scid)7345 DerefExpr::DerefExpr(Expr *e, SourcePos p, unsigned scid) : Expr(p, scid) { expr = e; }
7346
GetValue(FunctionEmitContext * ctx) const7347 llvm::Value *DerefExpr::GetValue(FunctionEmitContext *ctx) const {
7348 if (expr == NULL)
7349 return NULL;
7350 llvm::Value *ptr = expr->GetValue(ctx);
7351 if (ptr == NULL)
7352 return NULL;
7353 const Type *type = expr->GetType();
7354 if (type == NULL)
7355 return NULL;
7356
7357 if (lVaryingStructHasUniformMember(type, pos))
7358 return NULL;
7359
7360 // If dealing with 'varying * varying' add required offsets.
7361 ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, type);
7362
7363 Symbol *baseSym = expr->GetBaseSymbol();
7364 llvm::Value *mask = baseSym ? lMaskForSymbol(baseSym, ctx) : ctx->GetFullMask();
7365
7366 ctx->SetDebugPos(pos);
7367 return ctx->LoadInst(ptr, mask, type);
7368 }
7369
GetLValue(FunctionEmitContext * ctx) const7370 llvm::Value *DerefExpr::GetLValue(FunctionEmitContext *ctx) const {
7371 if (expr == NULL)
7372 return NULL;
7373 return expr->GetValue(ctx);
7374 }
7375
GetLValueType() const7376 const Type *DerefExpr::GetLValueType() const {
7377 if (expr == NULL)
7378 return NULL;
7379 return expr->GetType();
7380 }
7381
GetBaseSymbol() const7382 Symbol *DerefExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7383
Optimize()7384 Expr *DerefExpr::Optimize() {
7385 if (expr == NULL)
7386 return NULL;
7387 return this;
7388 }
7389
7390 ///////////////////////////////////////////////////////////////////////////
7391 // PtrDerefExpr
7392
PtrDerefExpr(Expr * e,SourcePos p)7393 PtrDerefExpr::PtrDerefExpr(Expr *e, SourcePos p) : DerefExpr(e, p, PtrDerefExprID) {}
7394
GetType() const7395 const Type *PtrDerefExpr::GetType() const {
7396 const Type *type;
7397 if (expr == NULL || (type = expr->GetType()) == NULL) {
7398 AssertPos(pos, m->errorCount > 0);
7399 return NULL;
7400 }
7401 AssertPos(pos, CastType<PointerType>(type) != NULL);
7402
7403 if (type->IsUniformType())
7404 return type->GetBaseType();
7405 else
7406 return type->GetBaseType()->GetAsVaryingType();
7407 }
7408
TypeCheck()7409 Expr *PtrDerefExpr::TypeCheck() {
7410 const Type *type;
7411 if (expr == NULL || (type = expr->GetType()) == NULL) {
7412 AssertPos(pos, m->errorCount > 0);
7413 return NULL;
7414 }
7415
7416 if (const PointerType *pt = CastType<PointerType>(type)) {
7417 if (pt->GetBaseType()->IsVoidType()) {
7418 Error(pos, "Illegal to dereference void pointer type \"%s\".", type->GetString().c_str());
7419 return NULL;
7420 }
7421 } else {
7422 Error(pos, "Illegal to dereference non-pointer type \"%s\".", type->GetString().c_str());
7423 return NULL;
7424 }
7425
7426 return this;
7427 }
7428
EstimateCost() const7429 int PtrDerefExpr::EstimateCost() const {
7430 const Type *type;
7431 if (expr == NULL || (type = expr->GetType()) == NULL) {
7432 AssertPos(pos, m->errorCount > 0);
7433 return 0;
7434 }
7435
7436 if (type->IsVaryingType())
7437 // Be pessimistic; some of these will later be optimized into
7438 // vector loads/stores..
7439 return COST_GATHER + COST_DEREF;
7440 else
7441 return COST_DEREF;
7442 }
7443
Print() const7444 void PtrDerefExpr::Print() const {
7445 if (expr == NULL || GetType() == NULL)
7446 return;
7447
7448 printf("[%s] *(", GetType()->GetString().c_str());
7449 expr->Print();
7450 printf(")");
7451 pos.Print();
7452 }
7453
7454 ///////////////////////////////////////////////////////////////////////////
7455 // RefDerefExpr
7456
RefDerefExpr(Expr * e,SourcePos p)7457 RefDerefExpr::RefDerefExpr(Expr *e, SourcePos p) : DerefExpr(e, p, RefDerefExprID) {}
7458
GetType() const7459 const Type *RefDerefExpr::GetType() const {
7460 const Type *type;
7461 if (expr == NULL || (type = expr->GetType()) == NULL) {
7462 AssertPos(pos, m->errorCount > 0);
7463 return NULL;
7464 }
7465
7466 AssertPos(pos, CastType<ReferenceType>(type) != NULL);
7467 return type->GetReferenceTarget();
7468 }
7469
TypeCheck()7470 Expr *RefDerefExpr::TypeCheck() {
7471 const Type *type;
7472 if (expr == NULL || (type = expr->GetType()) == NULL) {
7473 AssertPos(pos, m->errorCount > 0);
7474 return NULL;
7475 }
7476
7477 // We only create RefDerefExprs internally for references in
7478 // expressions, so we should never create one with a non-reference
7479 // expression...
7480 AssertPos(pos, CastType<ReferenceType>(type) != NULL);
7481
7482 return this;
7483 }
7484
EstimateCost() const7485 int RefDerefExpr::EstimateCost() const {
7486 if (expr == NULL)
7487 return 0;
7488
7489 return COST_DEREF;
7490 }
7491
Print() const7492 void RefDerefExpr::Print() const {
7493 if (expr == NULL || GetType() == NULL)
7494 return;
7495
7496 printf("[%s] deref-reference (", GetType()->GetString().c_str());
7497 expr->Print();
7498 printf(")");
7499 pos.Print();
7500 }
7501
7502 ///////////////////////////////////////////////////////////////////////////
7503 // AddressOfExpr
7504
AddressOfExpr(Expr * e,SourcePos p)7505 AddressOfExpr::AddressOfExpr(Expr *e, SourcePos p) : Expr(p, AddressOfExprID), expr(e) {}
7506
GetValue(FunctionEmitContext * ctx) const7507 llvm::Value *AddressOfExpr::GetValue(FunctionEmitContext *ctx) const {
7508 ctx->SetDebugPos(pos);
7509 if (expr == NULL)
7510 return NULL;
7511
7512 const Type *exprType = expr->GetType();
7513 if (CastType<ReferenceType>(exprType) != NULL || CastType<FunctionType>(exprType) != NULL)
7514 return expr->GetValue(ctx);
7515 else
7516 return expr->GetLValue(ctx);
7517 }
7518
GetType() const7519 const Type *AddressOfExpr::GetType() const {
7520 if (expr == NULL)
7521 return NULL;
7522
7523 const Type *exprType = expr->GetType();
7524 if (CastType<ReferenceType>(exprType) != NULL)
7525 return PointerType::GetUniform(exprType->GetReferenceTarget());
7526
7527 const Type *t = expr->GetLValueType();
7528 if (t != NULL)
7529 return t;
7530 else {
7531 t = expr->GetType();
7532 if (t == NULL) {
7533 AssertPos(pos, m->errorCount > 0);
7534 return NULL;
7535 }
7536 return PointerType::GetUniform(t);
7537 }
7538 }
7539
GetLValueType() const7540 const Type *AddressOfExpr::GetLValueType() const {
7541 if (!expr)
7542 return NULL;
7543
7544 const Type *type = expr->GetType();
7545 if (!type)
7546 return NULL;
7547
7548 return PointerType::GetUniform(type);
7549 }
7550
GetBaseSymbol() const7551 Symbol *AddressOfExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7552
Print() const7553 void AddressOfExpr::Print() const {
7554 printf("&(");
7555 if (expr)
7556 expr->Print();
7557 else
7558 printf("NULL expr");
7559 printf(")");
7560 pos.Print();
7561 }
7562
TypeCheck()7563 Expr *AddressOfExpr::TypeCheck() {
7564 const Type *exprType;
7565 if (expr == NULL || (exprType = expr->GetType()) == NULL) {
7566 AssertPos(pos, m->errorCount > 0);
7567 return NULL;
7568 }
7569
7570 if (CastType<ReferenceType>(exprType) != NULL || CastType<FunctionType>(exprType) != NULL) {
7571 return this;
7572 }
7573
7574 if (expr->GetLValueType() != NULL)
7575 return this;
7576
7577 Error(expr->pos, "Illegal to take address of non-lvalue or function.");
7578 return NULL;
7579 }
7580
Optimize()7581 Expr *AddressOfExpr::Optimize() { return this; }
7582
EstimateCost() const7583 int AddressOfExpr::EstimateCost() const { return 0; }
7584
/** Tries to express the address-of expression as an LLVM constant of
    the given type.

    @param type  Requested type; a (NULL, false) pair is returned unless
                 it is a pointer type.
    @return      A pair of the constant (NULL if none could be formed)
                 and the isNotValidForMultiTargetGlobal flag that is
                 propagated from the sub-expression constants used.

    Three shapes are handled: pointers to functions (delegated to the
    operand's GetConstant), the address of a symbol whose storage is an
    llvm::GlobalVariable, and a chain of constant-indexed IndexExprs
    rooted at such a symbol.
 */
std::pair<llvm::Constant *, bool> AddressOfExpr::GetConstant(const Type *type) const {
    if (expr == NULL || expr->GetType() == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return std::pair<llvm::Constant *, bool>(NULL, false);
    }

    const PointerType *pt = CastType<PointerType>(type);
    if (pt == NULL)
        return std::pair<llvm::Constant *, bool>(NULL, false);

    bool isNotValidForMultiTargetGlobal = false;
    // Pointer to function: let the operand produce the function constant
    // and convert it to the requested pointer type.
    const FunctionType *ft = CastType<FunctionType>(pt->GetBaseType());
    if (ft != NULL) {
        std::pair<llvm::Constant *, bool> cPair = expr->GetConstant(ft);
        llvm::Constant *c = cPair.first;
        return std::pair<llvm::Constant *, bool>(lConvertPointerConstant(c, type), cPair.second);
    }
    // Otherwise the address is only constant if the base symbol's storage
    // is a global variable.
    llvm::Value *ptr = NULL;
    if (GetBaseSymbol())
        ptr = GetBaseSymbol()->storagePtr;
    if (ptr && llvm::dyn_cast<llvm::GlobalVariable>(ptr)) {
        const Type *eTYPE = GetType();
        // Only proceed when the requested type and this expression's type
        // map to the same LLVM type.
        if (type->LLVMType(g->ctx) == eTYPE->LLVMType(g->ctx)) {
            if (llvm::dyn_cast<SymbolExpr>(expr) != NULL) {
                // &symbol: the global's storage pointer is the constant.
                return std::pair<llvm::Constant *, bool>(llvm::cast<llvm::Constant>(ptr),
                                                         isNotValidForMultiTargetGlobal);

            } else if (IndexExpr *IExpr = llvm::dyn_cast<IndexExpr>(expr)) {
                // &symbol[i][j]...: walk the IndexExpr chain from the
                // outermost index inwards, prepending each constant index
                // so that gepIndex ends up ordered from the base outwards.
                std::vector<llvm::Value *> gepIndex;
                Expr *mBaseExpr = NULL;
                while (IExpr) {
                    std::pair<llvm::Constant *, bool> cIndexPair = IExpr->index->GetConstant(IExpr->index->GetType());
                    llvm::Constant *cIndex = cIndexPair.first;
                    // A non-constant index means the address isn't constant.
                    if (cIndex == NULL)
                        return std::pair<llvm::Constant *, bool>(NULL, false);
                    gepIndex.insert(gepIndex.begin(), cIndex);
                    mBaseExpr = IExpr->baseExpr;
                    IExpr = llvm::dyn_cast<IndexExpr>(mBaseExpr);
                    isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || cIndexPair.second;
                }
                // The base expression needs to be a global symbol so that the
                // address is a constant.
                if (llvm::dyn_cast<SymbolExpr>(mBaseExpr) == NULL)
                    return std::pair<llvm::Constant *, bool>(NULL, false);
                // Leading zero index steps through the pointer to the global
                // itself before the per-dimension indices are applied.
                gepIndex.insert(gepIndex.begin(), LLVMInt64(0));
                llvm::Constant *c = llvm::cast<llvm::Constant>(ptr);
                llvm::Constant *c1 = llvm::ConstantExpr::getGetElementPtr(PTYPE(c), c, gepIndex);
                return std::pair<llvm::Constant *, bool>(c1, isNotValidForMultiTargetGlobal);
            }
        }
    }
    return std::pair<llvm::Constant *, bool>(NULL, false);
}
7638
7639 ///////////////////////////////////////////////////////////////////////////
7640 // SizeOfExpr
7641
SizeOfExpr(Expr * e,SourcePos p)7642 SizeOfExpr::SizeOfExpr(Expr *e, SourcePos p) : Expr(p, SizeOfExprID), expr(e), type(NULL) {}
7643
SizeOfExpr(const Type * t,SourcePos p)7644 SizeOfExpr::SizeOfExpr(const Type *t, SourcePos p) : Expr(p, SizeOfExprID), expr(NULL), type(t) {
7645 type = type->ResolveUnboundVariability(Variability::Varying);
7646 }
7647
GetValue(FunctionEmitContext * ctx) const7648 llvm::Value *SizeOfExpr::GetValue(FunctionEmitContext *ctx) const {
7649 ctx->SetDebugPos(pos);
7650 const Type *t = expr ? expr->GetType() : type;
7651 if (t == NULL)
7652 return NULL;
7653
7654 llvm::Type *llvmType = t->LLVMType(g->ctx);
7655 if (llvmType == NULL)
7656 return NULL;
7657
7658 return g->target->SizeOf(llvmType, ctx->GetCurrentBasicBlock());
7659 }
7660
GetType() const7661 const Type *SizeOfExpr::GetType() const {
7662 return (g->target->is32Bit() || g->opt.force32BitAddressing) ? AtomicType::UniformUInt32
7663 : AtomicType::UniformUInt64;
7664 }
7665
Print() const7666 void SizeOfExpr::Print() const {
7667 printf("Sizeof (");
7668 if (expr != NULL)
7669 expr->Print();
7670 const Type *t = expr ? expr->GetType() : type;
7671 if (t != NULL)
7672 printf(" [type %s]", t->GetString().c_str());
7673 printf(")");
7674 pos.Print();
7675 }
7676
TypeCheck()7677 Expr *SizeOfExpr::TypeCheck() {
7678 // Can't compute the size of a struct without a definition
7679 if (type != NULL && CastType<UndefinedStructType>(type) != NULL) {
7680 Error(pos,
7681 "Can't compute the size of declared but not defined "
7682 "struct type \"%s\".",
7683 type->GetString().c_str());
7684 return NULL;
7685 }
7686
7687 return this;
7688 }
7689
Optimize()7690 Expr *SizeOfExpr::Optimize() { return this; }
7691
EstimateCost() const7692 int SizeOfExpr::EstimateCost() const { return 0; }
7693
GetConstant(const Type * rtype) const7694 std::pair<llvm::Constant *, bool> SizeOfExpr::GetConstant(const Type *rtype) const {
7695 const Type *t = expr ? expr->GetType() : type;
7696 if (t == NULL)
7697 return std::pair<llvm::Constant *, bool>(NULL, false);
7698
7699 bool isNotValidForMultiTargetGlobal = false;
7700 if (t->IsVaryingType())
7701 isNotValidForMultiTargetGlobal = true;
7702
7703 llvm::Type *llvmType = t->LLVMType(g->ctx);
7704 if (llvmType == NULL)
7705 return std::pair<llvm::Constant *, bool>(NULL, false);
7706
7707 uint64_t byteSize = g->target->getDataLayout()->getTypeStoreSize(llvmType);
7708 return std::pair<llvm::Constant *, bool>(llvm::ConstantInt::get(rtype->LLVMType(g->ctx), byteSize),
7709 isNotValidForMultiTargetGlobal);
7710 }
7711
7712 ///////////////////////////////////////////////////////////////////////////
7713 // AllocaExpr
7714
AllocaExpr(Expr * e,SourcePos p)7715 AllocaExpr::AllocaExpr(Expr *e, SourcePos p) : Expr(p, AllocaExprID), expr(e) {}
7716
GetValue(FunctionEmitContext * ctx) const7717 llvm::Value *AllocaExpr::GetValue(FunctionEmitContext *ctx) const {
7718 ctx->SetDebugPos(pos);
7719 if (expr == NULL)
7720 return NULL;
7721 llvm::Value *llvmValue = expr->GetValue(ctx);
7722 if (llvmValue == NULL)
7723 return NULL;
7724 llvm::Value *resultPtr = ctx->AllocaInst((LLVMTypes::VoidPointerType)->getElementType(), llvmValue, "allocaExpr",
7725 16, false); // 16 byte stack alignment.
7726 return resultPtr;
7727 }
7728
GetType() const7729 const Type *AllocaExpr::GetType() const { return PointerType::Void; }
7730
Print() const7731 void AllocaExpr::Print() const {
7732 printf("AllocaExpr (");
7733 if (expr != NULL)
7734 expr->Print();
7735 const Type *t = expr ? expr->GetType() : NULL;
7736 if (t != NULL)
7737 printf(" [type %s]", t->GetString().c_str());
7738 printf(")");
7739 pos.Print();
7740 }
7741
TypeCheck()7742 Expr *AllocaExpr::TypeCheck() {
7743 if (expr == NULL) {
7744 return NULL;
7745 }
7746
7747 if (g->target->isGenXTarget()) {
7748 Error(pos, "\"alloca()\" is not supported for genx-* targets yet.");
7749 return NULL;
7750 }
7751 const Type *argType = expr ? expr->GetType() : NULL;
7752 const Type *sizeType = m->symbolTable->LookupType("size_t");
7753 Assert(sizeType != NULL);
7754 if (!Type::Equal(sizeType->GetAsUniformType(), expr->GetType())) {
7755 expr = TypeConvertExpr(expr, sizeType->GetAsUniformType(), "Alloca_arg");
7756 }
7757 if (expr == NULL) {
7758 Error(pos, "\"alloca()\" cannot have an argument of type \"%s\".", argType->GetString().c_str());
7759 return NULL;
7760 }
7761
7762 return this;
7763 }
7764
Optimize()7765 Expr *AllocaExpr::Optimize() { return this; }
7766
EstimateCost() const7767 int AllocaExpr::EstimateCost() const { return 0; }
7768
7769 ///////////////////////////////////////////////////////////////////////////
7770 // SymbolExpr
7771
SymbolExpr(Symbol * s,SourcePos p)7772 SymbolExpr::SymbolExpr(Symbol *s, SourcePos p) : Expr(p, SymbolExprID) { symbol = s; }
7773
GetValue(FunctionEmitContext * ctx) const7774 llvm::Value *SymbolExpr::GetValue(FunctionEmitContext *ctx) const {
7775 // storagePtr may be NULL due to an earlier compilation error
7776 if (!symbol || !symbol->storagePtr)
7777 return NULL;
7778 ctx->SetDebugPos(pos);
7779
7780 std::string loadName = symbol->name + std::string("_load");
7781 #ifdef ISPC_GENX_ENABLED
7782 // TODO: this is a temporary workaround and will be changed as part
7783 // of SPIR-V emitting solution
7784 if (ctx->emitGenXHardwareMask() && symbol->name == "__mask") {
7785 return ctx->GenXSimdCFPredicate(LLVMMaskAllOn);
7786 }
7787 #endif
7788 return ctx->LoadInst(symbol->storagePtr, symbol->type, loadName.c_str());
7789 }
7790
GetLValue(FunctionEmitContext * ctx) const7791 llvm::Value *SymbolExpr::GetLValue(FunctionEmitContext *ctx) const {
7792 if (symbol == NULL)
7793 return NULL;
7794 ctx->SetDebugPos(pos);
7795 return symbol->storagePtr;
7796 }
7797
GetLValueType() const7798 const Type *SymbolExpr::GetLValueType() const {
7799 if (symbol == NULL)
7800 return NULL;
7801
7802 if (CastType<ReferenceType>(symbol->type) != NULL)
7803 return PointerType::GetUniform(symbol->type->GetReferenceTarget());
7804 else
7805 return PointerType::GetUniform(symbol->type);
7806 }
7807
GetBaseSymbol() const7808 Symbol *SymbolExpr::GetBaseSymbol() const { return symbol; }
7809
GetType() const7810 const Type *SymbolExpr::GetType() const { return symbol ? symbol->type : NULL; }
7811
TypeCheck()7812 Expr *SymbolExpr::TypeCheck() { return this; }
7813
Optimize()7814 Expr *SymbolExpr::Optimize() {
7815 if (symbol == NULL)
7816 return NULL;
7817 else if (symbol->constValue != NULL) {
7818 AssertPos(pos, GetType()->IsConstType());
7819 return new ConstExpr(symbol->constValue, pos);
7820 } else
7821 return this;
7822 }
7823
EstimateCost() const7824 int SymbolExpr::EstimateCost() const {
7825 // Be optimistic and assume it's in a register or can be used as a
7826 // memory operand..
7827 return 0;
7828 }
7829
Print() const7830 void SymbolExpr::Print() const {
7831 if (symbol == NULL || GetType() == NULL)
7832 return;
7833
7834 printf("[%s] sym: (%s)", GetType()->GetString().c_str(), symbol->name.c_str());
7835 pos.Print();
7836 }
7837
7838 ///////////////////////////////////////////////////////////////////////////
7839 // FunctionSymbolExpr
7840
FunctionSymbolExpr(const char * n,const std::vector<Symbol * > & candidates,SourcePos p)7841 FunctionSymbolExpr::FunctionSymbolExpr(const char *n, const std::vector<Symbol *> &candidates, SourcePos p)
7842 : Expr(p, FunctionSymbolExprID) {
7843 name = n;
7844 candidateFunctions = candidates;
7845 matchingFunc = (candidates.size() == 1) ? candidates[0] : NULL;
7846 triedToResolve = false;
7847 }
7848
GetType() const7849 const Type *FunctionSymbolExpr::GetType() const {
7850 if (triedToResolve == false && matchingFunc == NULL) {
7851 Error(pos, "Ambiguous use of overloaded function \"%s\".", name.c_str());
7852 return NULL;
7853 }
7854
7855 return matchingFunc ? matchingFunc->type : NULL;
7856 }
7857
GetValue(FunctionEmitContext * ctx) const7858 llvm::Value *FunctionSymbolExpr::GetValue(FunctionEmitContext *ctx) const {
7859 return matchingFunc ? matchingFunc->function : NULL;
7860 }
7861
GetBaseSymbol() const7862 Symbol *FunctionSymbolExpr::GetBaseSymbol() const { return matchingFunc; }
7863
TypeCheck()7864 Expr *FunctionSymbolExpr::TypeCheck() { return this; }
7865
Optimize()7866 Expr *FunctionSymbolExpr::Optimize() { return this; }
7867
EstimateCost() const7868 int FunctionSymbolExpr::EstimateCost() const { return 0; }
7869
Print() const7870 void FunctionSymbolExpr::Print() const {
7871 if (!matchingFunc || !GetType())
7872 return;
7873
7874 printf("[%s] fun sym (%s)", GetType()->GetString().c_str(), matchingFunc->name.c_str());
7875 pos.Print();
7876 }
7877
GetConstant(const Type * type) const7878 std::pair<llvm::Constant *, bool> FunctionSymbolExpr::GetConstant(const Type *type) const {
7879 if (matchingFunc == NULL || matchingFunc->function == NULL)
7880 return std::pair<llvm::Constant *, bool>(NULL, false);
7881
7882 const FunctionType *ft = CastType<FunctionType>(type);
7883 if (ft == NULL)
7884 return std::pair<llvm::Constant *, bool>(NULL, false);
7885
7886 if (Type::Equal(type, matchingFunc->type) == false) {
7887 Error(pos,
7888 "Type of function symbol \"%s\" doesn't match expected type "
7889 "\"%s\".",
7890 matchingFunc->type->GetString().c_str(), type->GetString().c_str());
7891 return std::pair<llvm::Constant *, bool>(NULL, false);
7892 }
7893
7894 return std::pair<llvm::Constant *, bool>(matchingFunc->function, false);
7895 }
7896
lGetOverloadCandidateMessage(const std::vector<Symbol * > & funcs,const std::vector<const Type * > & argTypes,const std::vector<bool> * argCouldBeNULL)7897 static std::string lGetOverloadCandidateMessage(const std::vector<Symbol *> &funcs,
7898 const std::vector<const Type *> &argTypes,
7899 const std::vector<bool> *argCouldBeNULL) {
7900 std::string message = "Passed types: (";
7901 for (unsigned int i = 0; i < argTypes.size(); ++i) {
7902 if (argTypes[i] != NULL)
7903 message += argTypes[i]->GetString();
7904 else
7905 message += "(unknown type)";
7906 message += (i < argTypes.size() - 1) ? ", " : ")\n";
7907 }
7908
7909 for (unsigned int i = 0; i < funcs.size(); ++i) {
7910 const FunctionType *ft = CastType<FunctionType>(funcs[i]->type);
7911 Assert(ft != NULL);
7912 message += "Candidate: ";
7913 message += ft->GetString();
7914 if (i < funcs.size() - 1)
7915 message += "\n";
7916 }
7917 return message;
7918 }
7919
7920 /** Helper function used for function overload resolution: returns true if
7921 converting the argument to the call type only requires a type
7922 conversion that won't lose information. Otherwise return false.
7923 */
lIsMatchWithTypeWidening(const Type * callType,const Type * funcArgType)7924 static bool lIsMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
7925 const AtomicType *callAt = CastType<AtomicType>(callType);
7926 const AtomicType *funcAt = CastType<AtomicType>(funcArgType);
7927 if (callAt == NULL || funcAt == NULL)
7928 return false;
7929
7930 if (callAt->IsUniformType() != funcAt->IsUniformType())
7931 return false;
7932
7933 switch (callAt->basicType) {
7934 case AtomicType::TYPE_BOOL:
7935 return true;
7936 case AtomicType::TYPE_INT8:
7937 case AtomicType::TYPE_UINT8:
7938 return (funcAt->basicType != AtomicType::TYPE_BOOL);
7939 case AtomicType::TYPE_INT16:
7940 case AtomicType::TYPE_UINT16:
7941 return (funcAt->basicType != AtomicType::TYPE_BOOL && funcAt->basicType != AtomicType::TYPE_INT8 &&
7942 funcAt->basicType != AtomicType::TYPE_UINT8);
7943 case AtomicType::TYPE_INT32:
7944 case AtomicType::TYPE_UINT32:
7945 return (funcAt->basicType == AtomicType::TYPE_INT32 || funcAt->basicType == AtomicType::TYPE_UINT32 ||
7946 funcAt->basicType == AtomicType::TYPE_INT64 || funcAt->basicType == AtomicType::TYPE_UINT64);
7947 case AtomicType::TYPE_FLOAT:
7948 return (funcAt->basicType == AtomicType::TYPE_DOUBLE);
7949 case AtomicType::TYPE_INT64:
7950 case AtomicType::TYPE_UINT64:
7951 return (funcAt->basicType == AtomicType::TYPE_INT64 || funcAt->basicType == AtomicType::TYPE_UINT64);
7952 case AtomicType::TYPE_DOUBLE:
7953 return false;
7954 default:
7955 FATAL("Unhandled atomic type");
7956 return false;
7957 }
7958 }
7959
7960 /* Returns the set of function overloads that are potential matches, given
7961 argCount values being passed as arguments to the function call.
7962 */
getCandidateFunctions(int argCount) const7963 std::vector<Symbol *> FunctionSymbolExpr::getCandidateFunctions(int argCount) const {
7964 std::vector<Symbol *> ret;
7965 for (int i = 0; i < (int)candidateFunctions.size(); ++i) {
7966 const FunctionType *ft = CastType<FunctionType>(candidateFunctions[i]->type);
7967 AssertPos(pos, ft != NULL);
7968
7969 // There's no way to match if the caller is passing more arguments
7970 // than this function instance takes.
7971 if (argCount > ft->GetNumParameters())
7972 continue;
7973
7974 // Not enough arguments, and no default argument value to save us
7975 if (argCount < ft->GetNumParameters() && ft->GetParameterDefault(argCount) == NULL)
7976 continue;
7977
7978 // Success
7979 ret.push_back(candidateFunctions[i]);
7980 }
7981 return ret;
7982 }
7983
lArgIsPointerType(const Type * type)7984 static bool lArgIsPointerType(const Type *type) {
7985 if (CastType<PointerType>(type) != NULL)
7986 return true;
7987
7988 const ReferenceType *rt = CastType<ReferenceType>(type);
7989 if (rt == NULL)
7990 return false;
7991
7992 const Type *t = rt->GetReferenceTarget();
7993 return (CastType<PointerType>(t) != NULL);
7994 }
7995
/** This function computes the value of a cost function that represents the
    cost of calling a function of the given type with arguments of the
    given types.  If it's not possible to call the function, regardless of
    any type conversions applied, a cost of -1 is returned.

    @param ftype           Type of the candidate function.
    @param argTypes        Types of the arguments at the call site.
    @param argCouldBeNULL  Optional per-argument flags; when set for an
                           argument and the parameter is pointer-typed,
                           the argument is treated as a free match.
    @param argIsConstant   Optional per-argument flags; conversions of
                           arguments not flagged here are scaled up by
                           a factor of 512.
    @param cost            Output array with at least argTypes.size()
                           entries that receives the per-argument cost.
                           When -1 is returned, entries for arguments
                           after the failing one have not been written.
    @return                Sum of the per-argument costs, or -1.
 */
int FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype, const std::vector<const Type *> &argTypes,
                                            const std::vector<bool> *argCouldBeNULL,
                                            const std::vector<bool> *argIsConstant, int *cost) {
    int costSum = 0;

    // In computing the cost function, we only worry about the actual
    // argument types--using function default parameter values is free for
    // the purposes here...
    for (int i = 0; i < (int)argTypes.size(); ++i) {
        cost[i] = 0;
        // The cost imposed by this argument will be a multiple of
        // costScale, which has a value set so that for each of the cost
        // buckets, even if all of the function arguments undergo the next
        // lower-cost conversion, the sum of their costs will be less than
        // a single instance of the next higher-cost conversion.
        int costScale = argTypes.size() + 1;

        const Type *fargType = ftype->GetParameterType(i);
        const Type *callType = argTypes[i];

        if (Type::Equal(callType, fargType))
            // Perfect match: no cost
            // Step "1" from documentation
            cost[i] += 0;
        else if (argCouldBeNULL && (*argCouldBeNULL)[i] && lArgIsPointerType(fargType))
            // Passing NULL to a pointer-typed parameter is also a no-cost operation
            // Step "1" from documentation
            cost[i] += 0;
        else {
            // If the argument is a compile-time constant, we'd like to
            // count the cost of various conversions as much lower than the
            // cost if it wasn't--so scale up the cost when this isn't the
            // case..
            if (argIsConstant == NULL || (*argIsConstant)[i] == false)
                costScale *= 512;

            if (CastType<ReferenceType>(fargType)) {
                // Here we completely handle the case where fargType is reference.
                if (callType->IsConstType() && !fargType->IsConstType()) {
                    // It is forbidden to pass const object to non-const reference (cvf -> vfr)
                    return -1;
                }
                if (!callType->IsConstType() && fargType->IsConstType()) {
                    // It is possible to pass (vf -> cvfr)
                    // but it is worse than (vf -> vfr) or (cvf -> cvfr)
                    // Step "3" from documentation
                    cost[i] += 2 * costScale;
                }
                if (!Type::Equal(callType->GetReferenceTarget()->GetAsNonConstType(),
                                 fargType->GetReferenceTarget()->GetAsNonConstType())) {
                    // Types under references must be equal completely.
                    // vd -> vfr or vd -> cvfr are forbidden. (Although clang allows vd -> cvfr case.)
                    return -1;
                }
                // penalty for equal types under reference (vf -> vfr is worse than vf -> vf)
                // Step "2" from documentation
                cost[i] += 2 * costScale;
                continue;
            }
            const Type *callTypeNP = callType;
            if (CastType<ReferenceType>(callType)) {
                callTypeNP = callType->GetReferenceTarget();
                // we can treat vfr as vf for callType with some penalty
                // Step "5" from documentation
                cost[i] += 2 * costScale;
            }

            // Now we deal with references, so we can normalize to non-const types
            // because we're passing by value anyway, so const doesn't matter.
            const Type *callTypeNC = callTypeNP->GetAsNonConstType();
            const Type *fargTypeNC = fargType->GetAsNonConstType();

            // Now we forget about constants and references!
            if (Type::EqualIgnoringConst(callTypeNP, fargType)) {
                // The best case: vf -> vf.
                // Step "4" from documentation
                cost[i] += 1 * costScale;
                continue;
            }
            if (lIsMatchWithTypeWidening(callTypeNC, fargTypeNC)) {
                // A little bit worse case: vf -> vd.
                // Step "6" from documentation
                cost[i] += 8 * costScale;
                continue;
            }
            if (fargType->IsVaryingType() && callType->IsUniformType()) {
                // Here we deal with broadcasting uniform to varying.
                // callType - varying and fargType - uniform is forbidden.
                if (Type::Equal(callTypeNC->GetAsVaryingType(), fargTypeNC)) {
                    // uf -> vf is better than uf -> ui or uf -> ud
                    // Step "7" from documentation
                    cost[i] += 16 * costScale;
                    continue;
                }
                if (lIsMatchWithTypeWidening(callTypeNC->GetAsVaryingType(), fargTypeNC)) {
                    // uf -> vd is better than uf -> vi (128 < 128 + 64)
                    // but worse than uf -> ui (128 > 64)
                    // Step "9" from documentation
                    cost[i] += 128 * costScale;
                    continue;
                }
                // 128 + 64 is the max. uf -> vi is the worst case.
                // Step "10" from documentation
                // (No "continue" here: control falls through to the
                // conversion check below, which adds the remaining 64.)
                cost[i] += 128 * costScale;
            }
            if (CanConvertTypes(callTypeNC, fargTypeNC))
                // two cases: the worst is 128 + 64: uf -> vi and
                // the only 64: (64 < 128) uf -> ui worse than uf -> vd
                // Step "8" from documentation
                cost[i] += 64 * costScale;
            else
                // Failure--no type conversion possible...
                return -1;
        }
    }

    // Total cost is the sum of the per-argument costs.
    for (int i = 0; i < (int)argTypes.size(); ++i) {
        costSum = costSum + cost[i];
    }
    return costSum;
}
8122
ResolveOverloads(SourcePos argPos,const std::vector<const Type * > & argTypes,const std::vector<bool> * argCouldBeNULL,const std::vector<bool> * argIsConstant)8123 bool FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, const std::vector<const Type *> &argTypes,
8124 const std::vector<bool> *argCouldBeNULL,
8125 const std::vector<bool> *argIsConstant) {
8126 const char *funName = candidateFunctions.front()->name.c_str();
8127 if (triedToResolve == true) {
8128 return true;
8129 }
8130
8131 triedToResolve = true;
8132
8133 // Functions with names that start with "__" should only be various
8134 // builtins. For those, we'll demand an exact match, since we'll
8135 // expect whichever function in stdlib.ispc is calling out to one of
8136 // those to be matching the argument types exactly; this is to be a bit
8137 // extra safe to be sure that the expected builtin is in fact being
8138 // called.
8139 bool exactMatchOnly = (name.substr(0, 2) == "__");
8140
8141 // First, find the subset of overload candidates that take the same
8142 // number of arguments as have parameters (including functions that
8143 // take more arguments but have defaults starting no later than after
8144 // our last parameter).
8145 std::vector<Symbol *> actualCandidates = getCandidateFunctions(argTypes.size());
8146
8147 int bestMatchCost = 1 << 30;
8148 std::vector<Symbol *> matches;
8149 std::vector<int> candidateCosts;
8150 std::vector<int *> candidateExpandCosts;
8151
8152 if (actualCandidates.size() == 0)
8153 goto failure;
8154
8155 // Compute the cost for calling each of the candidate functions
8156 for (int i = 0; i < (int)actualCandidates.size(); ++i) {
8157 const FunctionType *ft = CastType<FunctionType>(actualCandidates[i]->type);
8158 AssertPos(pos, ft != NULL);
8159 int *cost = new int[argTypes.size()];
8160 candidateCosts.push_back(computeOverloadCost(ft, argTypes, argCouldBeNULL, argIsConstant, cost));
8161 candidateExpandCosts.push_back(cost);
8162 }
8163
8164 // Find the best cost, and then the candidate or candidates that have
8165 // that cost.
8166 for (int i = 0; i < (int)candidateCosts.size(); ++i) {
8167 if (candidateCosts[i] != -1 && candidateCosts[i] < bestMatchCost)
8168 bestMatchCost = candidateCosts[i];
8169 }
8170 // None of the candidates matched
8171 if (bestMatchCost == (1 << 30))
8172 goto failure;
8173 for (int i = 0; i < (int)candidateCosts.size(); ++i) {
8174 if (candidateCosts[i] == bestMatchCost) {
8175 for (int j = 0; j < (int)candidateCosts.size(); ++j) {
8176 for (int k = 0; k < argTypes.size(); k++) {
8177 if (candidateCosts[j] != -1 && candidateExpandCosts[j][k] < candidateExpandCosts[i][k]) {
8178 std::vector<Symbol *> temp;
8179 temp.push_back(actualCandidates[i]);
8180 temp.push_back(actualCandidates[j]);
8181 std::string candidateMessage = lGetOverloadCandidateMessage(temp, argTypes, argCouldBeNULL);
8182 Warning(pos,
8183 "call to \"%s\" is ambiguous. "
8184 "This warning will be turned into error in the next ispc release.\n"
8185 "Please add explicit cast to arguments to have unambiguous match."
8186 "\n%s",
8187 funName, candidateMessage.c_str());
8188 }
8189 }
8190 }
8191 matches.push_back(actualCandidates[i]);
8192 }
8193 }
8194 for (int i = 0; i < (int)candidateExpandCosts.size(); ++i) {
8195 delete[] candidateExpandCosts[i];
8196 }
8197
8198 if (matches.size() == 1) {
8199 // Only one match: success
8200 matchingFunc = matches[0];
8201 return true;
8202 } else if (matches.size() > 1) {
8203 // Multiple matches: ambiguous
8204 std::string candidateMessage = lGetOverloadCandidateMessage(matches, argTypes, argCouldBeNULL);
8205 Error(pos,
8206 "Multiple overloaded functions matched call to function "
8207 "\"%s\"%s.\n%s",
8208 funName, exactMatchOnly ? " only considering exact matches" : "", candidateMessage.c_str());
8209 return false;
8210 } else {
8211 // No matches at all
8212 failure:
8213 std::string candidateMessage = lGetOverloadCandidateMessage(matches, argTypes, argCouldBeNULL);
8214 Error(pos,
8215 "Unable to find any matching overload for call to function "
8216 "\"%s\"%s.\n%s",
8217 funName, exactMatchOnly ? " only considering exact matches" : "", candidateMessage.c_str());
8218 return false;
8219 }
8220 }
8221
GetMatchingFunction()8222 Symbol *FunctionSymbolExpr::GetMatchingFunction() { return matchingFunc; }
8223
8224 ///////////////////////////////////////////////////////////////////////////
8225 // SyncExpr
8226
GetType() const8227 const Type *SyncExpr::GetType() const { return AtomicType::Void; }
8228
GetValue(FunctionEmitContext * ctx) const8229 llvm::Value *SyncExpr::GetValue(FunctionEmitContext *ctx) const {
8230 ctx->SetDebugPos(pos);
8231 ctx->SyncInst();
8232 return NULL;
8233 }
8234
EstimateCost() const8235 int SyncExpr::EstimateCost() const { return COST_SYNC; }
8236
Print() const8237 void SyncExpr::Print() const {
8238 printf("sync");
8239 pos.Print();
8240 }
8241
TypeCheck()8242 Expr *SyncExpr::TypeCheck() { return this; }
8243
Optimize()8244 Expr *SyncExpr::Optimize() { return this; }
8245
8246 ///////////////////////////////////////////////////////////////////////////
8247 // NullPointerExpr
8248
GetValue(FunctionEmitContext * ctx) const8249 llvm::Value *NullPointerExpr::GetValue(FunctionEmitContext *ctx) const {
8250 return llvm::ConstantPointerNull::get(LLVMTypes::VoidPointerType);
8251 }
8252
GetType() const8253 const Type *NullPointerExpr::GetType() const { return PointerType::Void; }
8254
TypeCheck()8255 Expr *NullPointerExpr::TypeCheck() { return this; }
8256
Optimize()8257 Expr *NullPointerExpr::Optimize() { return this; }
8258
GetConstant(const Type * type) const8259 std::pair<llvm::Constant *, bool> NullPointerExpr::GetConstant(const Type *type) const {
8260 const PointerType *pt = CastType<PointerType>(type);
8261 if (pt == NULL)
8262 return std::pair<llvm::Constant *, bool>(NULL, false);
8263
8264 llvm::Type *llvmType = type->LLVMType(g->ctx);
8265 if (llvmType == NULL) {
8266 AssertPos(pos, m->errorCount > 0);
8267 return std::pair<llvm::Constant *, bool>(NULL, false);
8268 }
8269
8270 return std::pair<llvm::Constant *, bool>(llvm::Constant::getNullValue(llvmType), false);
8271 }
8272
Print() const8273 void NullPointerExpr::Print() const {
8274 printf("NULL");
8275 pos.Print();
8276 }
8277
EstimateCost() const8278 int NullPointerExpr::EstimateCost() const { return 0; }
8279
8280 ///////////////////////////////////////////////////////////////////////////
8281 // NewExpr
8282
NewExpr(int typeQual,const Type * t,Expr * init,Expr * count,SourcePos tqPos,SourcePos p)8283 NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count, SourcePos tqPos, SourcePos p)
8284 : Expr(p, NewExprID) {
8285 allocType = t;
8286
8287 initExpr = init;
8288 countExpr = count;
8289
8290 /* (The below cases actually should be impossible, since the parser
8291 doesn't allow more than a single type qualifier before a "new".) */
8292 if ((typeQual & ~(TYPEQUAL_UNIFORM | TYPEQUAL_VARYING)) != 0) {
8293 Error(tqPos, "Illegal type qualifiers in \"new\" expression (only "
8294 "\"uniform\" and \"varying\" are allowed.");
8295 isVarying = false;
8296 } else if ((typeQual & TYPEQUAL_UNIFORM) != 0 && (typeQual & TYPEQUAL_VARYING) != 0) {
8297 Error(tqPos, "Illegal to provide both \"uniform\" and \"varying\" "
8298 "qualifiers to \"new\" expression.");
8299 isVarying = false;
8300 } else
8301 // If no type qualifier is given before the 'new', treat it as a
8302 // varying new.
8303 isVarying = (typeQual == 0) || (typeQual & TYPEQUAL_VARYING);
8304
8305 if (allocType != NULL)
8306 allocType = allocType->ResolveUnboundVariability(Variability::Uniform);
8307 }
8308
GetValue(FunctionEmitContext * ctx) const8309 llvm::Value *NewExpr::GetValue(FunctionEmitContext *ctx) const {
8310 bool do32Bit = (g->target->is32Bit() || g->opt.force32BitAddressing);
8311
8312 // Determine how many elements we need to allocate. Note that this
8313 // will be a varying value if this is a varying new.
8314 llvm::Value *countValue;
8315 if (countExpr != NULL) {
8316 countValue = countExpr->GetValue(ctx);
8317 if (countValue == NULL) {
8318 AssertPos(pos, m->errorCount > 0);
8319 return NULL;
8320 }
8321 } else {
8322 if (isVarying) {
8323 if (do32Bit)
8324 countValue = LLVMInt32Vector(1);
8325 else
8326 countValue = LLVMInt64Vector(1);
8327 } else {
8328 if (do32Bit)
8329 countValue = LLVMInt32(1);
8330 else
8331 countValue = LLVMInt64(1);
8332 }
8333 }
8334
8335 // Compute the total amount of memory to allocate, allocSize, as the
8336 // product of the number of elements to allocate and the size of a
8337 // single element.
8338 llvm::Value *eltSize = g->target->SizeOf(allocType->LLVMType(g->ctx), ctx->GetCurrentBasicBlock());
8339 if (isVarying)
8340 eltSize = ctx->SmearUniform(eltSize, "smear_size");
8341 llvm::Value *allocSize = ctx->BinaryOperator(llvm::Instruction::Mul, countValue, eltSize, "alloc_size");
8342
8343 // Determine which allocation builtin function to call: uniform or
8344 // varying, and taking 32-bit or 64-bit allocation counts.
8345 llvm::Function *func;
8346 if (isVarying) {
8347 if (g->target->is32Bit()) {
8348 func = m->module->getFunction("__new_varying32_32rt");
8349 } else if (g->opt.force32BitAddressing) {
8350 func = m->module->getFunction("__new_varying32_64rt");
8351 } else {
8352 func = m->module->getFunction("__new_varying64_64rt");
8353 }
8354 } else {
8355 // FIXME: __new_uniform_32rt should take i32
8356 if (allocSize->getType() != LLVMTypes::Int64Type)
8357 allocSize = ctx->SExtInst(allocSize, LLVMTypes::Int64Type, "alloc_size64");
8358 if (g->target->is32Bit()) {
8359 func = m->module->getFunction("__new_uniform_32rt");
8360 } else {
8361 func = m->module->getFunction("__new_uniform_64rt");
8362 }
8363 }
8364 AssertPos(pos, func != NULL);
8365
8366 // Make the call for the the actual allocation.
8367 llvm::Value *ptrValue = ctx->CallInst(func, NULL, allocSize, "new");
8368
8369 // Now handle initializers and returning the right type for the result.
8370 const Type *retType = GetType();
8371 if (retType == NULL)
8372 return NULL;
8373 if (isVarying) {
8374 if (g->target->is32Bit())
8375 // Convert i64 vector values to i32 if we are compiling to a
8376 // 32-bit target.
8377 ptrValue = ctx->TruncInst(ptrValue, LLVMTypes::VoidPointerVectorType, "ptr_to_32bit");
8378
8379 if (initExpr != NULL) {
8380 // If we have an initializer expression, emit code that checks
8381 // to see if each lane is active and if so, runs the code to do
8382 // the initialization. Note that we're we're taking advantage
8383 // of the fact that the __new_varying*() functions are
8384 // implemented to return NULL for program instances that aren't
8385 // executing; more generally, we should be using the current
8386 // execution mask for this...
8387 for (int i = 0; i < g->target->getVectorWidth(); ++i) {
8388 llvm::BasicBlock *bbInit = ctx->CreateBasicBlock("init_ptr");
8389 llvm::BasicBlock *bbSkip = ctx->CreateBasicBlock("skip_init");
8390 llvm::Value *p = ctx->ExtractInst(ptrValue, i);
8391 llvm::Value *nullValue = g->target->is32Bit() ? LLVMInt32(0) : LLVMInt64(0);
8392 // Is the pointer for the current lane non-zero?
8393 llvm::Value *nonNull =
8394 ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, p, nullValue, "non_null");
8395 ctx->BranchInst(bbInit, bbSkip, nonNull);
8396
8397 // Initialize the memory pointed to by the pointer for the
8398 // current lane.
8399 ctx->SetCurrentBasicBlock(bbInit);
8400 llvm::Type *ptrType = retType->GetAsUniformType()->LLVMType(g->ctx);
8401 llvm::Value *ptr = ctx->IntToPtrInst(p, ptrType);
8402 InitSymbol(ptr, allocType, initExpr, ctx, pos);
8403 ctx->BranchInst(bbSkip);
8404
8405 ctx->SetCurrentBasicBlock(bbSkip);
8406 }
8407 }
8408
8409 return ptrValue;
8410 } else {
8411 // For uniform news, we just need to cast the void * to be a
8412 // pointer of the return type and to run the code for initializers,
8413 // if present.
8414 llvm::Type *ptrType = retType->LLVMType(g->ctx);
8415 ptrValue = ctx->BitCastInst(ptrValue, ptrType, llvm::Twine(ptrValue->getName()) + "_cast_ptr");
8416
8417 if (initExpr != NULL)
8418 InitSymbol(ptrValue, allocType, initExpr, ctx, pos);
8419
8420 return ptrValue;
8421 }
8422 }
8423
GetType() const8424 const Type *NewExpr::GetType() const {
8425 if (allocType == NULL)
8426 return NULL;
8427
8428 return isVarying ? PointerType::GetVarying(allocType) : PointerType::GetUniform(allocType);
8429 }
8430
TypeCheck()8431 Expr *NewExpr::TypeCheck() {
8432 // It's illegal to call new with an undefined struct type
8433 if (allocType == NULL) {
8434 AssertPos(pos, m->errorCount > 0);
8435 return NULL;
8436 }
8437
8438 if (g->target->isGenXTarget()) {
8439 Error(pos, "\"new\" is not supported for genx-* targets yet.");
8440 return NULL;
8441 }
8442
8443 if (CastType<UndefinedStructType>(allocType) != NULL) {
8444 Error(pos,
8445 "Can't dynamically allocate storage for declared "
8446 "but not defined type \"%s\".",
8447 allocType->GetString().c_str());
8448 return NULL;
8449 }
8450 const StructType *st = CastType<StructType>(allocType);
8451 if (st != NULL && !st->IsDefined()) {
8452 Error(pos,
8453 "Can't dynamically allocate storage for declared "
8454 "type \"%s\" containing undefined member type.",
8455 allocType->GetString().c_str());
8456 return NULL;
8457 }
8458
8459 // Otherwise we only need to make sure that if we have an expression
8460 // giving a number of elements to allocate that it can be converted to
8461 // an integer of the appropriate variability.
8462 if (countExpr == NULL)
8463 return this;
8464
8465 const Type *countType;
8466 if ((countType = countExpr->GetType()) == NULL)
8467 return NULL;
8468
8469 if (isVarying == false && countType->IsVaryingType()) {
8470 Error(pos, "Illegal to provide \"varying\" allocation count with "
8471 "\"uniform new\" expression.");
8472 return NULL;
8473 }
8474
8475 // Figure out the type that the allocation count should be
8476 const Type *t =
8477 (g->target->is32Bit() || g->opt.force32BitAddressing) ? AtomicType::UniformUInt32 : AtomicType::UniformUInt64;
8478 if (isVarying)
8479 t = t->GetAsVaryingType();
8480
8481 countExpr = TypeConvertExpr(countExpr, t, "item count");
8482 if (countExpr == NULL)
8483 return NULL;
8484
8485 return this;
8486 }
8487
Optimize()8488 Expr *NewExpr::Optimize() { return this; }
8489
Print() const8490 void NewExpr::Print() const { printf("new (%s)", allocType ? allocType->GetString().c_str() : "NULL"); }
8491
EstimateCost() const8492 int NewExpr::EstimateCost() const { return COST_NEW; }
8493