1 /*
2   Copyright (c) 2010-2021, Intel Corporation
3   All rights reserved.
4 
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions are
7   met:
8 
9     * Redistributions of source code must retain the above copyright
10       notice, this list of conditions and the following disclaimer.
11 
12     * Redistributions in binary form must reproduce the above copyright
13       notice, this list of conditions and the following disclaimer in the
14       documentation and/or other materials provided with the distribution.
15 
16     * Neither the name of Intel Corporation nor the names of its
17       contributors may be used to endorse or promote products derived from
18       this software without specific prior written permission.
19 
20 
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23    TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33 
34 /** @file expr.cpp
35     @brief Implementations of expression classes
36 */
37 
38 #include "expr.h"
39 #include "ast.h"
40 #include "ctx.h"
41 #include "llvmutil.h"
42 #include "module.h"
43 #include "sym.h"
44 #include "type.h"
45 #include "util.h"
46 
47 #ifndef _MSC_VER
48 #include <inttypes.h>
49 #endif
50 #ifndef PRId64
51 #define PRId64 "lld"
52 #endif
53 #ifndef PRIu64
54 #define PRIu64 "llu"
55 #endif
56 
57 #include <list>
58 #include <set>
59 #include <stdio.h>
60 
61 #include <llvm/ExecutionEngine/GenericValue.h>
62 #include <llvm/IR/CallingConv.h>
63 #include <llvm/IR/DerivedTypes.h>
64 #include <llvm/IR/Function.h>
65 #include <llvm/IR/InstIterator.h>
66 #include <llvm/IR/Instructions.h>
67 #include <llvm/IR/LLVMContext.h>
68 #include <llvm/IR/Module.h>
69 #include <llvm/IR/Type.h>
70 
71 using namespace ispc;
72 
73 /////////////////////////////////////////////////////////////////////////////////////
74 // Expr
75 
GetLValue(FunctionEmitContext * ctx) const76 llvm::Value *Expr::GetLValue(FunctionEmitContext *ctx) const {
77     // Expressions that can't provide an lvalue can just return NULL
78     return NULL;
79 }
80 
GetLValueType() const81 const Type *Expr::GetLValueType() const {
82     // This also only needs to be overrided by Exprs that implement the
83     // GetLValue() method.
84     return NULL;
85 }
86 
GetStorageConstant(const Type * type) const87 std::pair<llvm::Constant *, bool> Expr::GetStorageConstant(const Type *type) const { return GetConstant(type); }
GetConstant(const Type * type) const88 std::pair<llvm::Constant *, bool> Expr::GetConstant(const Type *type) const {
89     // The default is failure; just return NULL
90     return std::pair<llvm::Constant *, bool>(NULL, false);
91 }
92 
GetBaseSymbol() const93 Symbol *Expr::GetBaseSymbol() const {
94     // Not all expressions can do this, so provide a generally-useful
95     // default implementation.
96     return NULL;
97 }
98 
HasAmbiguousVariability(std::vector<const Expr * > & warn) const99 bool Expr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const { return false; }
100 
#if 0
// NOTE: everything up to the matching #endif is compiled out; it is kept
// here for reference only.

/** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost
    precision, issue a warning.  Don't warn for conversions to bool and
    conversions between signed and unsigned integers of the same size.

    The "may lose precision" test compares the integer values of the two
    basicType enumerants, so it presumably relies on the enumerants being
    declared in order of increasing range/precision -- TODO confirm
    against the AtomicType declaration.

    @param toAtomicType    Destination atomic type of the conversion
    @param fromAtomicType  Source atomic type of the conversion
    @param pos             Source position passed to Warning()
    @param errorMsgBase    Text substituted for the trailing "for %s" in
                           the warning message
 */
static void
lMaybeIssuePrecisionWarning(const AtomicType *toAtomicType,
                            const AtomicType *fromAtomicType,
                            SourcePos pos, const char *errorMsgBase) {
    switch (toAtomicType->basicType) {
    case AtomicType::TYPE_BOOL:
    case AtomicType::TYPE_INT8:
    case AtomicType::TYPE_UINT8:
    case AtomicType::TYPE_INT16:
    case AtomicType::TYPE_UINT16:
    case AtomicType::TYPE_INT32:
    case AtomicType::TYPE_UINT32:
    case AtomicType::TYPE_FLOAT:
    case AtomicType::TYPE_INT64:
    case AtomicType::TYPE_UINT64:
    case AtomicType::TYPE_DOUBLE:
        // Warn only when the destination enumerant is numerically smaller
        // than the source one, the destination isn't bool, and the pair
        // isn't a same-size unsigned -> signed integer conversion.
        if ((int)toAtomicType->basicType < (int)fromAtomicType->basicType &&
            toAtomicType->basicType != AtomicType::TYPE_BOOL &&
            !(toAtomicType->basicType == AtomicType::TYPE_INT8 &&
              fromAtomicType->basicType == AtomicType::TYPE_UINT8) &&
            !(toAtomicType->basicType == AtomicType::TYPE_INT16 &&
              fromAtomicType->basicType == AtomicType::TYPE_UINT16) &&
            !(toAtomicType->basicType == AtomicType::TYPE_INT32 &&
              fromAtomicType->basicType == AtomicType::TYPE_UINT32) &&
            !(toAtomicType->basicType == AtomicType::TYPE_INT64 &&
              fromAtomicType->basicType == AtomicType::TYPE_UINT64))
            Warning(pos, "Conversion from type \"%s\" to type \"%s\" for %s"
                    " may lose information.",
                    fromAtomicType->GetString().c_str(), toAtomicType->GetString().c_str(),
                    errorMsgBase);
        break;
    default:
        // Any destination basicType not listed above is treated as an
        // internal error.
        FATAL("logic error in lMaybeIssuePrecisionWarning()");
    }
}
#endif
142 
143 ///////////////////////////////////////////////////////////////////////////
144 
lArrayToPointer(Expr * expr)145 static Expr *lArrayToPointer(Expr *expr) {
146     Assert(expr != NULL);
147     AssertPos(expr->pos, CastType<ArrayType>(expr->GetType()));
148 
149     Expr *zero = new ConstExpr(AtomicType::UniformInt32, 0, expr->pos);
150     Expr *index = new IndexExpr(expr, zero, expr->pos);
151     Expr *addr = new AddressOfExpr(index, expr->pos);
152     addr = TypeCheck(addr);
153     Assert(addr != NULL);
154     addr = Optimize(addr);
155     Assert(addr != NULL);
156     return addr;
157 }
158 
lIsAllIntZeros(Expr * expr)159 static bool lIsAllIntZeros(Expr *expr) {
160     const Type *type = expr->GetType();
161     if (type == NULL || type->IsIntType() == false)
162         return false;
163 
164     ConstExpr *ce = llvm::dyn_cast<ConstExpr>(expr);
165     if (ce == NULL)
166         return false;
167 
168     uint64_t vals[ISPC_MAX_NVEC];
169     int count = ce->GetValues(vals);
170     if (count == 1)
171         return (vals[0] == 0);
172     else {
173         for (int i = 0; i < count; ++i)
174             if (vals[i] != 0)
175                 return false;
176     }
177     return true;
178 }
179 
/** Core implicit type-conversion routine: decides whether a value of type
    'fromType' can be converted to 'toType' and, when an expression is
    supplied, rewrites it so that it yields the converted value.

    @param fromType      Source type; if NULL the function returns false
                         without reporting anything.
    @param toType        Destination type; if NULL the function returns
                         false without reporting anything.
    @param expr          Optional in/out expression.  If non-NULL and the
                         conversion succeeds via a cast, *expr is replaced
                         with a new expression (a TypeCastExpr, possibly
                         wrapping an AddressOfExpr, RefDerefExpr,
                         ReferenceExpr or NullPointerExpr).  If NULL, only
                         the legality check is performed.
    @param failureOk     If true, failures are silent; if false, an Error()
                         is issued describing why the conversion failed.
    @param errorMsgBase  Text substituted for the "for %s" part of the
                         diagnostics.  Must be non-NULL unless failureOk is
                         true (asserted below).
    @param pos           Source position used for diagnostics and for any
                         newly created expressions.
    @return true if the conversion is legal (including the case of equal
            types, where *expr is left untouched), false otherwise.
 */
static bool lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, bool failureOk, const char *errorMsgBase,
                        SourcePos pos) {
    /* This function is way too long and complex.  Is type conversion stuff
       always this messy, or can this be cleaned up somehow? */
    AssertPos(pos, failureOk || errorMsgBase != NULL);

    if (toType == NULL || fromType == NULL)
        return false;

    // The types are equal; there's nothing to do
    if (Type::Equal(toType, fromType))
        return true;

    // void can be neither the source nor the destination of a conversion.
    if (fromType->IsVoidType()) {
        if (!failureOk)
            Error(pos, "Can't convert from \"void\" to \"%s\" for %s.", toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    if (toType->IsVoidType()) {
        if (!failureOk)
            Error(pos, "Can't convert type \"%s\" to \"void\" for %s.", fromType->GetString().c_str(), errorMsgBase);
        return false;
    }

    if (CastType<FunctionType>(fromType)) {
        if (CastType<PointerType>(toType) != NULL) {
            // Convert function type to pointer to function type
            if (expr != NULL) {
                Expr *aoe = new AddressOfExpr(*expr, (*expr)->pos);
                if (lDoTypeConv(aoe->GetType(), toType, &aoe, failureOk, errorMsgBase, pos)) {
                    *expr = aoe;
                    return true;
                }
                // NOTE(review): when the recursive conversion fails there
                // is no "return false" here, so control falls out of this
                // block and continues with the checks below using the
                // original function type.  The reference-handling branches
                // further down return false in the analogous situation;
                // confirm whether this fall-through is intentional.
            } else
                return lDoTypeConv(PointerType::GetUniform(fromType), toType, NULL, failureOk, errorMsgBase, pos);
        } else {
            if (!failureOk)
                Error(pos, "Can't convert function type \"%s\" to \"%s\" for %s.", fromType->GetString().c_str(),
                      toType->GetString().c_str(), errorMsgBase);
            return false;
        }
    }
    // Nothing converts to a bare function type.
    if (CastType<FunctionType>(toType)) {
        if (!failureOk)
            Error(pos,
                  "Can't convert from type \"%s\" to function type \"%s\" "
                  "for %s.",
                  fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    // Types that are the same once made uniform, but whose SOA widths
    // differ (with at least one of them being an SOA type), can't be
    // converted.
    if ((toType->GetSOAWidth() > 0 || fromType->GetSOAWidth() > 0) &&
        Type::Equal(toType->GetAsUniformType(), fromType->GetAsUniformType()) &&
        toType->GetSOAWidth() != fromType->GetSOAWidth()) {
        if (!failureOk)
            Error(pos,
                  "Can't convert between types \"%s\" and \"%s\" with "
                  "different SOA widths for %s.",
                  fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    // Classify both types once up front.  Note that these are computed
    // from the types as passed in, i.e. before fromType/toType are
    // replaced by their reference targets further below.
    const ArrayType *toArrayType = CastType<ArrayType>(toType);
    const ArrayType *fromArrayType = CastType<ArrayType>(fromType);
    const VectorType *toVectorType = CastType<VectorType>(toType);
    const VectorType *fromVectorType = CastType<VectorType>(fromType);
    const StructType *toStructType = CastType<StructType>(toType);
    const StructType *fromStructType = CastType<StructType>(fromType);
    const EnumType *toEnumType = CastType<EnumType>(toType);
    const EnumType *fromEnumType = CastType<EnumType>(fromType);
    const AtomicType *toAtomicType = CastType<AtomicType>(toType);
    const AtomicType *fromAtomicType = CastType<AtomicType>(fromType);
    const PointerType *fromPointerType = CastType<PointerType>(fromType);
    const PointerType *toPointerType = CastType<PointerType>(toType);

    // Do this early, since for the case of a conversion like
    // "float foo[10]" -> "float * uniform foo", we have what's seemingly
    // a varying to uniform conversion (but not really)
    if (fromArrayType != NULL && toPointerType != NULL) {
        // can convert any array to a void pointer (both uniform and
        // varying).
        if (PointerType::IsVoidPointer(toPointerType))
            goto typecast_ok;

        // array to pointer to array element type
        const Type *eltType = fromArrayType->GetElementType();
        if (toPointerType->GetBaseType()->IsConstType())
            eltType = eltType->GetAsConstType();

        // Build the pointer type we'd expect for this array (taking the
        // variability and constness from the destination) and require the
        // destination to match it exactly.
        PointerType pt(eltType, toPointerType->GetVariability(), toPointerType->IsConstType());
        if (Type::Equal(toPointerType, &pt))
            goto typecast_ok;
        else {
            if (!failureOk)
                Error(pos,
                      "Can't convert from incompatible array type \"%s\" "
                      "to pointer type \"%s\" for %s.",
                      fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
            return false;
        }
    }

    // varying -> uniform is never allowed.
    if (toType->IsUniformType() && fromType->IsVaryingType()) {
        if (!failureOk)
            Error(pos, "Can't convert from type \"%s\" to type \"%s\" for %s.", fromType->GetString().c_str(),
                  toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    // All conversions *from* a pointer type are fully handled inside this
    // block; every path through it returns or jumps to typecast_ok.
    if (fromPointerType != NULL) {
        if (CastType<AtomicType>(toType) != NULL && toType->IsBoolType())
            // Allow implicit conversion of pointers to bools
            goto typecast_ok;

        if (toArrayType != NULL && Type::Equal(fromType->GetBaseType(), toArrayType->GetElementType())) {
            // Can convert pointers to arrays of the same type
            goto typecast_ok;
        }
        if (toPointerType == NULL) {
            if (!failureOk)
                Error(pos,
                      "Can't convert between from pointer type "
                      "\"%s\" to non-pointer type \"%s\" for %s.",
                      fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
            return false;
        } else if (fromPointerType->IsSlice() == true && toPointerType->IsSlice() == false) {
            if (!failureOk)
                Error(pos,
                      "Can't convert from pointer to SOA type "
                      "\"%s\" to pointer to non-SOA type \"%s\" for %s.",
                      fromPointerType->GetAsNonSlice()->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
            return false;
        } else if (PointerType::IsVoidPointer(toPointerType)) {
            // Converting to void * must not drop a const qualifier on the
            // pointed-to type.
            if (fromPointerType->GetBaseType()->IsConstType() && !(toPointerType->GetBaseType()->IsConstType())) {
                if (!failureOk)
                    Error(pos, "Can't convert pointer to const \"%s\" to void pointer.",
                          fromPointerType->GetString().c_str());
                return false;
            }
            // any pointer type can be converted to a void *
            // ...almost. #731
            goto typecast_ok;
        } else if (PointerType::IsVoidPointer(fromPointerType) && expr != NULL &&
                   llvm::dyn_cast<NullPointerExpr>(*expr) != NULL) {
            // and a NULL convert to any other pointer type
            goto typecast_ok;
        } else if (!Type::Equal(fromPointerType->GetBaseType(), toPointerType->GetBaseType()) &&
                   !Type::Equal(fromPointerType->GetBaseType()->GetAsConstType(), toPointerType->GetBaseType())) {
            // The base types differ, and not merely by the destination
            // adding const.
            // for const * -> * conversion, print warning.
            if (Type::EqualIgnoringConst(fromPointerType->GetBaseType(), toPointerType->GetBaseType())) {
                if (!Type::Equal(fromPointerType->GetBaseType()->GetAsConstType(), toPointerType->GetBaseType())) {
                    // NOTE(review): this warning is emitted regardless of
                    // failureOk, and errorMsgBase is allowed to be NULL
                    // when failureOk is true (see the assertion at the top
                    // of the function), so a NULL may reach the "%s"
                    // here -- confirm whether that can happen in practice.
                    Warning(pos,
                            "Converting from const pointer type \"%s\" to "
                            "pointer type \"%s\" for %s discards const qualifier.",
                            fromPointerType->GetString().c_str(), toPointerType->GetString().c_str(), errorMsgBase);
                }
            } else {
                if (!failureOk) {
                    Error(pos,
                          "Can't convert from pointer type \"%s\" to "
                          "incompatible pointer type \"%s\" for %s.",
                          fromPointerType->GetString().c_str(), toPointerType->GetString().c_str(), errorMsgBase);
                }
                return false;
            }
        }

        // Compatible pointers: a cast is still needed for a uniform ->
        // varying conversion...
        if (toType->IsVaryingType() && fromType->IsUniformType())
            goto typecast_ok;

        // ...and for a non-slice -> slice pointer conversion.
        if (toPointerType->IsSlice() == true && fromPointerType->IsSlice() == false)
            goto typecast_ok;

        // Otherwise there's nothing to do
        return true;
    }

    if (toPointerType != NULL && fromAtomicType != NULL && fromAtomicType->IsIntType() && expr != NULL &&
        lIsAllIntZeros(*expr)) {
        // We have a zero-valued integer expression, which can also be
        // treated as a NULL pointer that can be converted to any other
        // pointer type.
        Expr *npe = new NullPointerExpr(pos);
        if (lDoTypeConv(PointerType::Void, toType, &npe, failureOk, errorMsgBase, pos)) {
            *expr = npe;
            return true;
        }
        return false;
    }

    // Need to check this early, since otherwise the [sic] "unbound"
    // variability of SOA struct types causes things to get messy if that
    // hasn't been detected...
    if (toStructType && fromStructType && (toStructType->GetSOAWidth() != fromStructType->GetSOAWidth())) {
        if (!failureOk)
            Error(pos,
                  "Can't convert between incompatible struct types \"%s\" "
                  "and \"%s\" for %s.",
                  fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    // Convert from type T -> const T; just return a TypeCast expr, which
    // can handle this
    if (Type::EqualIgnoringConst(toType, fromType) && toType->IsConstType() == true && fromType->IsConstType() == false)
        goto typecast_ok;

    if (CastType<ReferenceType>(fromType)) {
        if (CastType<ReferenceType>(toType)) {
            // Convert from a reference to a type to a const reference to a type;
            // this is handled by TypeCastExpr
            if (Type::Equal(toType->GetReferenceTarget(), fromType->GetReferenceTarget()->GetAsConstType()))
                goto typecast_ok;

            // Reference-to-array -> reference-to-array is allowed when the
            // element types are equal (the element counts are not
            // compared here).
            const ArrayType *atFrom = CastType<ArrayType>(fromType->GetReferenceTarget());
            const ArrayType *atTo = CastType<ArrayType>(toType->GetReferenceTarget());

            if (atFrom != NULL && atTo != NULL && Type::Equal(atFrom->GetElementType(), atTo->GetElementType())) {
                goto typecast_ok;
            } else {
                if (!failureOk)
                    Error(pos,
                          "Can't convert between incompatible reference types \"%s\" "
                          "and \"%s\" for %s.",
                          fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
                return false;
            }
        } else {
            // convert from a reference T -> T
            // Dereference first, then recursively convert the dereferenced
            // value to the destination type.
            if (expr != NULL) {
                Expr *drExpr = new RefDerefExpr(*expr, pos);
                if (lDoTypeConv(drExpr->GetType(), toType, &drExpr, failureOk, errorMsgBase, pos) == true) {
                    *expr = drExpr;
                    return true;
                }
                return false;
            } else
                return lDoTypeConv(fromType->GetReferenceTarget(), toType, NULL, failureOk, errorMsgBase, pos);
        }
    } else if (CastType<ReferenceType>(toType)) {
        // T -> reference T
        // Take a reference first, then recursively convert that reference
        // to the destination reference type.
        if (expr != NULL) {
            Expr *rExpr = new ReferenceExpr(*expr, pos);
            if (lDoTypeConv(rExpr->GetType(), toType, &rExpr, failureOk, errorMsgBase, pos) == true) {
                *expr = rExpr;
                return true;
            }
            return false;
        } else {
            // No expression to wrap: use a temporary reference type on the
            // stack purely for the legality check.
            ReferenceType rt(fromType);
            return lDoTypeConv(&rt, toType, NULL, failureOk, errorMsgBase, pos);
        }
    } else if (Type::Equal(toType, fromType->GetAsNonConstType()))
        // convert: const T -> T (as long as T isn't a reference)
        goto typecast_ok;

    // From here on work with the reference targets of both types.  The
    // to*/from* classification pointers above are not recomputed.
    fromType = fromType->GetReferenceTarget();
    toType = toType->GetReferenceTarget();
    if (toArrayType && fromArrayType) {
        if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType())) {
            // the case of different element counts should have returned
            // successfully earlier, yes??
            AssertPos(pos, toArrayType->GetElementCount() != fromArrayType->GetElementCount());
            goto typecast_ok;
        } else if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType()->GetAsConstType())) {
            // T[x] -> const T[x]
            goto typecast_ok;
        } else {
            if (!failureOk)
                Error(pos, "Array type \"%s\" can't be converted to type \"%s\" for %s.", fromType->GetString().c_str(),
                      toType->GetString().c_str(), errorMsgBase);
            return false;
        }
    }

    if (toVectorType && fromVectorType) {
        // converting e.g. int<n> -> float<n>
        if (fromVectorType->GetElementCount() != toVectorType->GetElementCount()) {
            if (!failureOk)
                Error(pos,
                      "Can't convert between differently sized vector types "
                      "\"%s\" -> \"%s\" for %s.",
                      fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
            return false;
        }
        goto typecast_ok;
    }

    if (toStructType && fromStructType) {
        // Struct types must match once variability and constness are
        // normalized (both compared as uniform const).
        if (!Type::Equal(toStructType->GetAsUniformType()->GetAsConstType(),
                         fromStructType->GetAsUniformType()->GetAsConstType())) {
            if (!failureOk)
                Error(pos,
                      "Can't convert between different struct types "
                      "\"%s\" and \"%s\" for %s.",
                      fromStructType->GetString().c_str(), toStructType->GetString().c_str(), errorMsgBase);
            return false;
        }
        goto typecast_ok;
    }

    if (toEnumType != NULL && fromEnumType != NULL) {
        // No implicit conversions between different enum types
        if (!Type::EqualIgnoringConst(toEnumType->GetAsUniformType(), fromEnumType->GetAsUniformType())) {
            if (!failureOk)
                Error(pos,
                      "Can't convert between different enum types "
                      "\"%s\" and \"%s\" for %s",
                      fromEnumType->GetString().c_str(), toEnumType->GetString().c_str(), errorMsgBase);
            return false;
        }
        goto typecast_ok;
    }

    // enum -> atomic (integer, generally...) is always ok
    if (fromEnumType != NULL) {
        // Cannot convert to anything other than atomic
        // (a vector destination is also let through by this check)
        if (toAtomicType == NULL && toVectorType == NULL) {
            if (!failureOk)
                Error(pos,
                      "Type conversion from \"%s\" to \"%s\" for %s is not "
                      "possible.",
                      fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
            return false;
        }
        goto typecast_ok;
    }

    // from here on out, the from type can only be atomic something or
    // other...
    if (fromAtomicType == NULL) {
        if (!failureOk)
            Error(pos,
                  "Type conversion from \"%s\" to \"%s\" for %s is not "
                  "possible.",
                  fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    // scalar -> short-vector conversions
    if (toVectorType != NULL && (fromType->GetSOAWidth() == toType->GetSOAWidth()))
        goto typecast_ok;

    // ok, it better be a scalar->scalar conversion of some sort by now
    if (toAtomicType == NULL) {
        if (!failureOk)
            Error(pos,
                  "Type conversion from \"%s\" to \"%s\" for %s is "
                  "not possible",
                  fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    if (fromType->GetSOAWidth() != toType->GetSOAWidth()) {
        if (!failureOk)
            Error(pos,
                  "Can't convert between types \"%s\" and \"%s\" with "
                  "different SOA widths for %s.",
                  fromType->GetString().c_str(), toType->GetString().c_str(), errorMsgBase);
        return false;
    }

    // Common success exit for every conversion that needs an explicit
    // cast: wrap the caller's expression (if any) in a TypeCastExpr to the
    // destination type.  Atomic -> atomic conversions that pass the checks
    // above fall through to here as well.
typecast_ok:
    if (expr != NULL)
        *expr = new TypeCastExpr(toType, *expr, pos);
    return true;
}
548 
CanConvertTypes(const Type * fromType,const Type * toType,const char * errorMsgBase,SourcePos pos)549 bool ispc::CanConvertTypes(const Type *fromType, const Type *toType, const char *errorMsgBase, SourcePos pos) {
550     return lDoTypeConv(fromType, toType, NULL, errorMsgBase == NULL, errorMsgBase, pos);
551 }
552 
TypeConvertExpr(Expr * expr,const Type * toType,const char * errorMsgBase)553 Expr *ispc::TypeConvertExpr(Expr *expr, const Type *toType, const char *errorMsgBase) {
554     if (expr == NULL)
555         return NULL;
556 
557 #if 0
558     Debug(expr->pos, "type convert %s -> %s.", expr->GetType()->GetString().c_str(),
559           toType->GetString().c_str());
560 #endif
561 
562     const Type *fromType = expr->GetType();
563     Expr *e = expr;
564     if (lDoTypeConv(fromType, toType, &e, false, errorMsgBase, expr->pos))
565         return e;
566     else
567         return NULL;
568 }
569 
PossiblyResolveFunctionOverloads(Expr * expr,const Type * type)570 bool ispc::PossiblyResolveFunctionOverloads(Expr *expr, const Type *type) {
571     FunctionSymbolExpr *fse = NULL;
572     const FunctionType *funcType = NULL;
573     if (CastType<PointerType>(type) != NULL && (funcType = CastType<FunctionType>(type->GetBaseType())) &&
574         (fse = llvm::dyn_cast<FunctionSymbolExpr>(expr)) != NULL) {
575         // We're initializing a function pointer with a function symbol,
576         // which in turn may represent an overloaded function.  So we need
577         // to try to resolve the overload based on the type of the symbol
578         // we're initializing here.
579         std::vector<const Type *> paramTypes;
580         for (int i = 0; i < funcType->GetNumParameters(); ++i)
581             paramTypes.push_back(funcType->GetParameterType(i));
582 
583         if (fse->ResolveOverloads(expr->pos, paramTypes) == false)
584             return false;
585     }
586     return true;
587 }
588 
/** Utility routine that emits code to initialize a symbol given an
    initializer expression.

    @param ptr       Memory location of storage for the symbol's data
    @param symType   Type of variable being initialized
    @param initExpr  Expression for the initializer
    @param ctx       FunctionEmitContext to use for generating instructions
    @param pos       Source file position of the variable being initialized
*/
InitSymbol(llvm::Value * ptr,const Type * symType,Expr * initExpr,FunctionEmitContext * ctx,SourcePos pos)599 void ispc::InitSymbol(llvm::Value *ptr, const Type *symType, Expr *initExpr, FunctionEmitContext *ctx, SourcePos pos) {
600     if (initExpr == NULL)
601         // leave it uninitialized
602         return;
603 
604     // See if we have a constant initializer a this point
605     std::pair<llvm::Constant *, bool> constValPair = initExpr->GetStorageConstant(symType);
606     llvm::Constant *constValue = constValPair.first;
607     if (constValue != NULL) {
608         // It'd be nice if we could just do a StoreInst(constValue, ptr)
609         // at this point, but unfortunately that doesn't generate great
610         // code (e.g. a bunch of scalar moves for a constant array.)  So
611         // instead we'll make a constant static global that holds the
612         // constant value and emit a memcpy to put its value into the
613         // pointer we have.
614         llvm::Type *llvmType = symType->LLVMStorageType(g->ctx);
615         if (llvmType == NULL) {
616             AssertPos(pos, m->errorCount > 0);
617             return;
618         }
619 
620         if (Type::IsBasicType(symType))
621             ctx->StoreInst(constValue, ptr, symType, symType->IsUniformType());
622         else {
623             llvm::Value *constPtr =
624                 new llvm::GlobalVariable(*m->module, llvmType, true /* const */, llvm::GlobalValue::InternalLinkage,
625                                          constValue, "const_initializer");
626             llvm::Value *size = g->target->SizeOf(llvmType, ctx->GetCurrentBasicBlock());
627             ctx->MemcpyInst(ptr, constPtr, size);
628         }
629 
630         return;
631     }
632 
633     // If the initializer is a straight up expression that isn't an
634     // ExprList, then we'll see if we can type convert it to the type of
635     // the variable.
636     if (llvm::dyn_cast<ExprList>(initExpr) == NULL) {
637         if (PossiblyResolveFunctionOverloads(initExpr, symType) == false)
638             return;
639         initExpr = TypeConvertExpr(initExpr, symType, "initializer");
640 
641         if (initExpr == NULL)
642             return;
643 
644         llvm::Value *initializerValue = initExpr->GetValue(ctx);
645         if (initializerValue != NULL)
646             // Bingo; store the value in the variable's storage
647             ctx->StoreInst(initializerValue, ptr, symType, symType->IsUniformType());
648         return;
649     }
650 
651     // Atomic types and enums can be initialized with { ... } initializer
652     // expressions if they have a single element (except for SOA types,
653     // which are handled below).
654     if (symType->IsSOAType() == false && Type::IsBasicType(symType)) {
655         ExprList *elist = llvm::dyn_cast<ExprList>(initExpr);
656         if (elist != NULL) {
657             if (elist->exprs.size() == 1) {
658                 InitSymbol(ptr, symType, elist->exprs[0], ctx, pos);
659                 return;
660             } else if (symType->IsVaryingType() == false) {
661                 Error(initExpr->pos,
662                       "Expression list initializers with "
663                       "multiple values can't be used with type \"%s\".",
664                       symType->GetString().c_str());
665                 return;
666             }
667         } else
668             return;
669     }
670 
671     const ReferenceType *rt = CastType<ReferenceType>(symType);
672     if (rt) {
673         if (!Type::Equal(initExpr->GetType(), rt)) {
674             Error(initExpr->pos,
675                   "Initializer for reference type \"%s\" must have same "
676                   "reference type itself. \"%s\" is incompatible.",
677                   rt->GetString().c_str(), initExpr->GetType()->GetString().c_str());
678             return;
679         }
680 
681         llvm::Value *initializerValue = initExpr->GetValue(ctx);
682         if (initializerValue)
683             ctx->StoreInst(initializerValue, ptr, initExpr->GetType(), symType->IsUniformType());
684         return;
685     }
686 
687     // Handle initiailizers for SOA types as well as for structs, arrays,
688     // and vectors.
689     const CollectionType *collectionType = CastType<CollectionType>(symType);
690     if (collectionType != NULL || symType->IsSOAType() ||
691         (Type::IsBasicType(symType) && symType->IsVaryingType() == true)) {
692         // Make default value equivalent to number of elements for varying
693         int nElements = g->target->getVectorWidth();
694         if (collectionType)
695             nElements = collectionType->GetElementCount();
696         else if (symType->IsSOAType())
697             nElements = symType->GetSOAWidth();
698 
699         std::string name;
700         if (CastType<StructType>(symType) != NULL)
701             name = "struct";
702         else if (CastType<ArrayType>(symType) != NULL)
703             name = "array";
704         else if (CastType<VectorType>(symType) != NULL)
705             name = "vector";
706         else if (symType->IsSOAType() || (Type::IsBasicType(symType) && symType->IsVaryingType() == true))
707             name = symType->GetVariability().GetString();
708         else
709             FATAL("Unexpected CollectionType in InitSymbol()");
710 
711         // There are two cases for initializing these types; either a
712         // single initializer may be provided (float foo[3] = 0;), in which
713         // case all of the elements are initialized to the given value, or
714         // an initializer list may be provided (float foo[3] = { 1,2,3 }),
715         // in which case the elements are initialized with the
716         // corresponding values.
717         ExprList *exprList = llvm::dyn_cast<ExprList>(initExpr);
718         if (exprList != NULL) {
719             // The { ... } case; make sure we have the no more expressions
720             // in the ExprList as we have struct members
721             int nInits = exprList->exprs.size();
722             if (nInits > nElements) {
723                 Error(initExpr->pos,
724                       "Initializer for %s type \"%s\" requires "
725                       "no more than %d values; %d provided.",
726                       name.c_str(), symType->GetString().c_str(), nElements, nInits);
727                 return;
728             } else if ((Type::IsBasicType(symType) && symType->IsVaryingType() == true) && (nInits < nElements)) {
729                 Error(initExpr->pos,
730                       "Initializer for %s type \"%s\" requires "
731                       "%d values; %d provided.",
732                       name.c_str(), symType->GetString().c_str(), nElements, nInits);
733                 return;
734             }
735 
736             // Initialize each element with the corresponding value from
737             // the ExprList
738             for (int i = 0; i < nElements; ++i) {
739                 // For SOA types and varying, the element type is the uniform variant
740                 // of the underlying type
741                 const Type *elementType =
742                     collectionType ? collectionType->GetElementType(i) : symType->GetAsUniformType();
743 
744                 llvm::Value *ep;
745                 if (CastType<StructType>(symType) != NULL)
746                     ep = ctx->AddElementOffset(ptr, i, NULL, "element");
747                 else
748                     ep = ctx->GetElementPtrInst(ptr, LLVMInt32(0), LLVMInt32(i), PointerType::GetUniform(elementType),
749                                                 "gep");
750 
751                 if (i < nInits)
752                     InitSymbol(ep, elementType, exprList->exprs[i], ctx, pos);
753                 else {
754                     // If we don't have enough initializer values, initialize the
755                     // rest as zero.
756                     llvm::Type *llvmType = elementType->LLVMStorageType(g->ctx);
757                     if (llvmType == NULL) {
758                         AssertPos(pos, m->errorCount > 0);
759                         return;
760                     }
761 
762                     llvm::Constant *zeroInit = llvm::Constant::getNullValue(llvmType);
763                     ctx->StoreInst(zeroInit, ep, elementType, elementType->IsUniformType());
764                 }
765             }
766         } else if (collectionType) {
767             Error(initExpr->pos, "Can't assign type \"%s\" to \"%s\".", initExpr->GetType()->GetString().c_str(),
768                   collectionType->GetString().c_str());
769         } else {
770             FATAL("CollectionType is NULL in InitSymbol()");
771         }
772         return;
773     }
774 
775     FATAL("Unexpected Type in InitSymbol()");
776 }
777 
778 ///////////////////////////////////////////////////////////////////////////
779 
780 /** Given an atomic or vector type, this returns a boolean type with the
781     same "shape".  In other words, if the given type is a vector type of
782     three uniform ints, the returned type is a vector type of three uniform
783     bools. */
lMatchingBoolType(const Type * type)784 static const Type *lMatchingBoolType(const Type *type) {
785     bool uniformTest = type->IsUniformType();
786     const AtomicType *boolBase = uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool;
787     const VectorType *vt = CastType<VectorType>(type);
788     if (vt != NULL)
789         return new VectorType(boolBase, vt->GetElementCount());
790     else {
791         Assert(Type::IsBasicType(type) || type->IsReferenceType());
792         return boolBase;
793     }
794 }
795 
796 ///////////////////////////////////////////////////////////////////////////
797 // UnaryExpr
798 
lLLVMConstantValue(const Type * type,llvm::LLVMContext * ctx,double value)799 static llvm::Constant *lLLVMConstantValue(const Type *type, llvm::LLVMContext *ctx, double value) {
800     const AtomicType *atomicType = CastType<AtomicType>(type);
801     const EnumType *enumType = CastType<EnumType>(type);
802     const VectorType *vectorType = CastType<VectorType>(type);
803     const PointerType *pointerType = CastType<PointerType>(type);
804 
805     // This function is only called with, and only works for atomic, enum,
806     // and vector types.
807     Assert(atomicType != NULL || enumType != NULL || vectorType != NULL || pointerType != NULL);
808 
809     if (atomicType != NULL || enumType != NULL) {
810         // If it's an atomic or enuemrator type, then figure out which of
811         // the llvmutil.h functions to call to get the corresponding
812         // constant and then call it...
813         bool isUniform = type->IsUniformType();
814         AtomicType::BasicType basicType = (enumType != NULL) ? AtomicType::TYPE_UINT32 : atomicType->basicType;
815 
816         switch (basicType) {
817         case AtomicType::TYPE_VOID:
818             FATAL("can't get constant value for void type");
819             return NULL;
820         case AtomicType::TYPE_BOOL:
821             if (isUniform)
822                 return (value != 0.) ? LLVMTrue : LLVMFalse;
823             else
824                 return LLVMBoolVector(value != 0.);
825         case AtomicType::TYPE_INT8: {
826             int i = (int)value;
827             Assert((double)i == value);
828             return isUniform ? LLVMInt8(i) : LLVMInt8Vector(i);
829         }
830         case AtomicType::TYPE_UINT8: {
831             unsigned int i = (unsigned int)value;
832             return isUniform ? LLVMUInt8(i) : LLVMUInt8Vector(i);
833         }
834         case AtomicType::TYPE_INT16: {
835             int i = (int)value;
836             Assert((double)i == value);
837             return isUniform ? LLVMInt16(i) : LLVMInt16Vector(i);
838         }
839         case AtomicType::TYPE_UINT16: {
840             unsigned int i = (unsigned int)value;
841             return isUniform ? LLVMUInt16(i) : LLVMUInt16Vector(i);
842         }
843         case AtomicType::TYPE_INT32: {
844             int i = (int)value;
845             Assert((double)i == value);
846             return isUniform ? LLVMInt32(i) : LLVMInt32Vector(i);
847         }
848         case AtomicType::TYPE_UINT32: {
849             unsigned int i = (unsigned int)value;
850             return isUniform ? LLVMUInt32(i) : LLVMUInt32Vector(i);
851         }
852         case AtomicType::TYPE_FLOAT:
853             return isUniform ? LLVMFloat((float)value) : LLVMFloatVector((float)value);
854         case AtomicType::TYPE_UINT64: {
855             uint64_t i = (uint64_t)value;
856             Assert(value == (int64_t)i);
857             return isUniform ? LLVMUInt64(i) : LLVMUInt64Vector(i);
858         }
859         case AtomicType::TYPE_INT64: {
860             int64_t i = (int64_t)value;
861             Assert((double)i == value);
862             return isUniform ? LLVMInt64(i) : LLVMInt64Vector(i);
863         }
864         case AtomicType::TYPE_DOUBLE:
865             return isUniform ? LLVMDouble(value) : LLVMDoubleVector(value);
866         default:
867             FATAL("logic error in lLLVMConstantValue");
868             return NULL;
869         }
870     } else if (pointerType != NULL) {
871         Assert(value == 0);
872         if (pointerType->IsUniformType())
873             return llvm::Constant::getNullValue(LLVMTypes::VoidPointerType);
874         else
875             return llvm::Constant::getNullValue(LLVMTypes::VoidPointerVectorType);
876     } else {
877         // For vector types, first get the LLVM constant for the basetype with
878         // a recursive call to lLLVMConstantValue().
879         const Type *baseType = vectorType->GetBaseType();
880         llvm::Constant *constElement = lLLVMConstantValue(baseType, ctx, value);
881         llvm::Type *llvmVectorType = vectorType->LLVMType(ctx);
882 
883         // Now create a constant version of the corresponding LLVM type that we
884         // use to represent the VectorType.
885         // FIXME: this is a little ugly in that the fact that ispc represents
886         // uniform VectorTypes as LLVM VectorTypes and varying VectorTypes as
887         // LLVM ArrayTypes leaks into the code here; it feels like this detail
888         // should be better encapsulated?
889         if (baseType->IsUniformType()) {
890 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
891             llvm::FixedVectorType *lvt = llvm::dyn_cast<llvm::FixedVectorType>(llvmVectorType);
892 #else
893             llvm::VectorType *lvt = llvm::dyn_cast<llvm::VectorType>(llvmVectorType);
894 #endif
895             Assert(lvt != NULL);
896             std::vector<llvm::Constant *> vals;
897             for (unsigned int i = 0; i < lvt->getNumElements(); ++i)
898                 vals.push_back(constElement);
899             return llvm::ConstantVector::get(vals);
900         } else {
901             llvm::ArrayType *lat = llvm::dyn_cast<llvm::ArrayType>(llvmVectorType);
902             Assert(lat != NULL);
903             std::vector<llvm::Constant *> vals;
904             for (unsigned int i = 0; i < lat->getNumElements(); ++i)
905                 vals.push_back(constElement);
906             return llvm::ConstantArray::get(lat, vals);
907         }
908     }
909 }
910 
lMaskForSymbol(Symbol * baseSym,FunctionEmitContext * ctx)911 static llvm::Value *lMaskForSymbol(Symbol *baseSym, FunctionEmitContext *ctx) {
912     if (baseSym == NULL)
913         return ctx->GetFullMask();
914 
915     if (CastType<PointerType>(baseSym->type) != NULL || CastType<ReferenceType>(baseSym->type) != NULL)
916         // FIXME: for pointers, we really only want to do this for
917         // dereferencing the pointer, not for things like pointer
918         // arithmetic, when we may be able to use the internal mask,
919         // depending on context...
920         return ctx->GetFullMask();
921 
922     llvm::Value *mask = (baseSym->parentFunction == ctx->GetFunction() && baseSym->storageClass != SC_STATIC)
923                             ? ctx->GetInternalMask()
924                             : ctx->GetFullMask();
925     return mask;
926 }
927 
928 /** Store the result of an assignment to the given location.
929  */
lStoreAssignResult(llvm::Value * value,llvm::Value * ptr,const Type * valueType,const Type * ptrType,FunctionEmitContext * ctx,Symbol * baseSym)930 static void lStoreAssignResult(llvm::Value *value, llvm::Value *ptr, const Type *valueType, const Type *ptrType,
931                                FunctionEmitContext *ctx, Symbol *baseSym) {
932     Assert(baseSym == NULL || baseSym->varyingCFDepth <= ctx->VaryingCFDepth());
933     if (!g->opt.disableMaskedStoreToStore && !g->opt.disableMaskAllOnOptimizations && baseSym != NULL &&
934         baseSym->varyingCFDepth == ctx->VaryingCFDepth() && baseSym->storageClass != SC_STATIC &&
935         CastType<ReferenceType>(baseSym->type) == NULL && CastType<PointerType>(baseSym->type) == NULL) {
936         // If the variable is declared at the same varying control flow
937         // depth as where it's being assigned, then we don't need to do any
938         // masking but can just do the assignment as if all the lanes were
939         // known to be on.  While this may lead to random/garbage values
940         // written into the lanes that are off, by definition they will
941         // never be accessed, since those lanes aren't executing, and won't
942         // be executing at this scope or any other one before the variable
943         // goes out of scope.
944         ctx->StoreInst(value, ptr, LLVMMaskAllOn, valueType, ptrType);
945     } else {
946         ctx->StoreInst(value, ptr, lMaskForSymbol(baseSym, ctx), valueType, ptrType);
947     }
948 }
949 
950 /** Utility routine to emit code to do a {pre,post}-{inc,dec}rement of the
951     given expresion.
952  */
lEmitPrePostIncDec(UnaryExpr::Op op,Expr * expr,SourcePos pos,FunctionEmitContext * ctx)953 static llvm::Value *lEmitPrePostIncDec(UnaryExpr::Op op, Expr *expr, SourcePos pos, FunctionEmitContext *ctx) {
954     const Type *type = expr->GetType();
955     if (type == NULL)
956         return NULL;
957 
958     // Get both the lvalue and the rvalue of the given expression
959     llvm::Value *lvalue = NULL, *rvalue = NULL;
960     const Type *lvalueType = NULL;
961     if (CastType<ReferenceType>(type) != NULL) {
962         lvalueType = type;
963         type = type->GetReferenceTarget();
964         lvalue = expr->GetValue(ctx);
965 
966         RefDerefExpr *deref = new RefDerefExpr(expr, expr->pos);
967         rvalue = deref->GetValue(ctx);
968     } else {
969         lvalue = expr->GetLValue(ctx);
970         lvalueType = expr->GetLValueType();
971         rvalue = expr->GetValue(ctx);
972     }
973 
974     if (lvalue == NULL) {
975         // If we can't get a lvalue, then we have an error here
976         const char *prepost = (op == UnaryExpr::PreInc || op == UnaryExpr::PreDec) ? "pre" : "post";
977         const char *incdec = (op == UnaryExpr::PreInc || op == UnaryExpr::PostInc) ? "increment" : "decrement";
978         Error(pos, "Can't %s-%s non-lvalues.", prepost, incdec);
979         return NULL;
980     }
981 
982     // Emit code to do the appropriate addition/subtraction to the
983     // expression's old value
984     ctx->SetDebugPos(pos);
985     llvm::Value *binop = NULL;
986     int delta = (op == UnaryExpr::PreInc || op == UnaryExpr::PostInc) ? 1 : -1;
987 
988     std::string opName = rvalue->getName().str();
989     if (op == UnaryExpr::PreInc || op == UnaryExpr::PostInc)
990         opName += "_plus1";
991     else
992         opName += "_minus1";
993 
994     if (CastType<PointerType>(type) != NULL) {
995         const Type *incType = type->IsUniformType() ? AtomicType::UniformInt32 : AtomicType::VaryingInt32;
996         llvm::Constant *dval = lLLVMConstantValue(incType, g->ctx, delta);
997         binop = ctx->GetElementPtrInst(rvalue, dval, type, opName.c_str());
998     } else {
999         llvm::Constant *dval = lLLVMConstantValue(type, g->ctx, delta);
1000         if (type->IsFloatType())
1001             binop = ctx->BinaryOperator(llvm::Instruction::FAdd, rvalue, dval, opName.c_str());
1002         else
1003             binop = ctx->BinaryOperator(llvm::Instruction::Add, rvalue, dval, opName.c_str());
1004     }
1005 
1006     // And store the result out to the lvalue
1007     Symbol *baseSym = expr->GetBaseSymbol();
1008     lStoreAssignResult(binop, lvalue, type, lvalueType, ctx, baseSym);
1009 
1010     // And then if it's a pre increment/decrement, return the final
1011     // computed result; otherwise return the previously-grabbed expression
1012     // value.
1013     return (op == UnaryExpr::PreInc || op == UnaryExpr::PreDec) ? binop : rvalue;
1014 }
1015 
1016 /** Utility routine to emit code to negate the given expression.
1017  */
lEmitNegate(Expr * arg,SourcePos pos,FunctionEmitContext * ctx)1018 static llvm::Value *lEmitNegate(Expr *arg, SourcePos pos, FunctionEmitContext *ctx) {
1019     const Type *type = arg->GetType();
1020     llvm::Value *argVal = arg->GetValue(ctx);
1021     if (type == NULL || argVal == NULL)
1022         return NULL;
1023 
1024     // Negate by subtracting from zero...
1025     ctx->SetDebugPos(pos);
1026     if (type->IsFloatType()) {
1027         llvm::Value *zero = llvm::ConstantFP::getZeroValueForNegation(type->LLVMType(g->ctx));
1028         return ctx->BinaryOperator(llvm::Instruction::FSub, zero, argVal, llvm::Twine(argVal->getName()) + "_negate");
1029     } else {
1030         llvm::Value *zero = lLLVMConstantValue(type, g->ctx, 0.);
1031         AssertPos(pos, type->IsIntType());
1032         return ctx->BinaryOperator(llvm::Instruction::Sub, zero, argVal, llvm::Twine(argVal->getName()) + "_negate");
1033     }
1034 }
1035 
UnaryExpr(Op o,Expr * e,SourcePos p)1036 UnaryExpr::UnaryExpr(Op o, Expr *e, SourcePos p) : Expr(p, UnaryExprID), op(o) { expr = e; }
1037 
GetValue(FunctionEmitContext * ctx) const1038 llvm::Value *UnaryExpr::GetValue(FunctionEmitContext *ctx) const {
1039     if (expr == NULL)
1040         return NULL;
1041 
1042     ctx->SetDebugPos(pos);
1043 
1044     switch (op) {
1045     case PreInc:
1046     case PreDec:
1047     case PostInc:
1048     case PostDec:
1049         return lEmitPrePostIncDec(op, expr, pos, ctx);
1050     case Negate:
1051         return lEmitNegate(expr, pos, ctx);
1052     case LogicalNot: {
1053         llvm::Value *argVal = expr->GetValue(ctx);
1054         return ctx->NotOperator(argVal, llvm::Twine(argVal->getName()) + "_logicalnot");
1055     }
1056     case BitNot: {
1057         llvm::Value *argVal = expr->GetValue(ctx);
1058         return ctx->NotOperator(argVal, llvm::Twine(argVal->getName()) + "_bitnot");
1059     }
1060     default:
1061         FATAL("logic error");
1062         return NULL;
1063     }
1064 }
1065 
GetType() const1066 const Type *UnaryExpr::GetType() const {
1067     if (expr == NULL)
1068         return NULL;
1069 
1070     const Type *type = expr->GetType();
1071     if (type == NULL)
1072         return NULL;
1073 
1074     // For all unary expressions besides logical not, the returned type is
1075     // the same as the source type.  Logical not always returns a bool
1076     // type, with the same shape as the input type.
1077     switch (op) {
1078     case PreInc:
1079     case PreDec:
1080     case PostInc:
1081     case PostDec:
1082     case Negate:
1083     case BitNot:
1084         return type;
1085     case LogicalNot:
1086         return lMatchingBoolType(type);
1087     default:
1088         FATAL("error");
1089         return NULL;
1090     }
1091 }
1092 
lOptimizeBitNot(ConstExpr * constExpr,const Type * type,SourcePos pos)1093 template <typename T> static Expr *lOptimizeBitNot(ConstExpr *constExpr, const Type *type, SourcePos pos) {
1094     T v[ISPC_MAX_NVEC];
1095     int count = constExpr->GetValues(v);
1096     for (int i = 0; i < count; ++i)
1097         v[i] = ~v[i];
1098     return new ConstExpr(type, v, pos);
1099 }
1100 
lOptimizeNegate(ConstExpr * constExpr,const Type * type,SourcePos pos)1101 template <typename T> static Expr *lOptimizeNegate(ConstExpr *constExpr, const Type *type, SourcePos pos) {
1102     T v[ISPC_MAX_NVEC];
1103     int count = constExpr->GetValues(v);
1104     for (int i = 0; i < count; ++i)
1105         v[i] = -v[i];
1106     return new ConstExpr(type, v, pos);
1107 }
1108 
Optimize()1109 Expr *UnaryExpr::Optimize() {
1110     ConstExpr *constExpr = llvm::dyn_cast<ConstExpr>(expr);
1111     // If the operand isn't a constant, then we can't do any optimization
1112     // here...
1113     if (constExpr == NULL)
1114         return this;
1115 
1116     const Type *type = constExpr->GetType();
1117     bool isEnumType = CastType<EnumType>(type) != NULL;
1118 
1119     switch (op) {
1120     case PreInc:
1121     case PreDec:
1122     case PostInc:
1123     case PostDec:
1124         // this shouldn't happen--it's illegal to modify a contant value..
1125         // An error will be issued elsewhere...
1126         return this;
1127     case Negate: {
1128         if (Type::EqualIgnoringConst(type, AtomicType::UniformInt64) ||
1129             Type::EqualIgnoringConst(type, AtomicType::VaryingInt64)) {
1130             return lOptimizeNegate<int64_t>(constExpr, type, pos);
1131         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt64) ||
1132                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt64)) {
1133             return lOptimizeNegate<uint64_t>(constExpr, type, pos);
1134         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt32) ||
1135                    Type::EqualIgnoringConst(type, AtomicType::VaryingInt32)) {
1136             return lOptimizeNegate<int32_t>(constExpr, type, pos);
1137         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt32) ||
1138                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt32)) {
1139             return lOptimizeNegate<uint32_t>(constExpr, type, pos);
1140         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt16) ||
1141                    Type::EqualIgnoringConst(type, AtomicType::VaryingInt16)) {
1142             return lOptimizeNegate<int16_t>(constExpr, type, pos);
1143         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt16) ||
1144                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt32)) {
1145             return lOptimizeNegate<uint16_t>(constExpr, type, pos);
1146         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt8) ||
1147                    Type::EqualIgnoringConst(type, AtomicType::VaryingInt8)) {
1148             return lOptimizeNegate<int8_t>(constExpr, type, pos);
1149         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt8) ||
1150                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt8)) {
1151             return lOptimizeNegate<uint8_t>(constExpr, type, pos);
1152         } else {
1153             // For all the other types, it's safe to stuff whatever we have
1154             // into a double, do the negate as a double, and then return a
1155             // ConstExpr with the same type as the original...
1156             double v[ISPC_MAX_NVEC];
1157             int count = constExpr->GetValues(v);
1158             for (int i = 0; i < count; ++i)
1159                 v[i] = -v[i];
1160             return new ConstExpr(constExpr, v);
1161         }
1162     }
1163     case BitNot: {
1164         if (Type::EqualIgnoringConst(type, AtomicType::UniformInt8) ||
1165             Type::EqualIgnoringConst(type, AtomicType::VaryingInt8)) {
1166             return lOptimizeBitNot<int8_t>(constExpr, type, pos);
1167         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt8) ||
1168                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt8)) {
1169             return lOptimizeBitNot<uint8_t>(constExpr, type, pos);
1170         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt16) ||
1171                    Type::EqualIgnoringConst(type, AtomicType::VaryingInt16)) {
1172             return lOptimizeBitNot<int16_t>(constExpr, type, pos);
1173         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt16) ||
1174                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt16)) {
1175             return lOptimizeBitNot<uint16_t>(constExpr, type, pos);
1176         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt32) ||
1177                    Type::EqualIgnoringConst(type, AtomicType::VaryingInt32)) {
1178             return lOptimizeBitNot<int32_t>(constExpr, type, pos);
1179         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt32) ||
1180                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt32) || isEnumType == true) {
1181             return lOptimizeBitNot<uint32_t>(constExpr, type, pos);
1182         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformInt64) ||
1183                    Type::EqualIgnoringConst(type, AtomicType::VaryingInt64)) {
1184             return lOptimizeBitNot<int64_t>(constExpr, type, pos);
1185         } else if (Type::EqualIgnoringConst(type, AtomicType::UniformUInt64) ||
1186                    Type::EqualIgnoringConst(type, AtomicType::VaryingUInt64)) {
1187             return lOptimizeBitNot<uint64_t>(constExpr, type, pos);
1188         } else
1189             FATAL("unexpected type in UnaryExpr::Optimize() / BitNot case");
1190     }
1191     case LogicalNot: {
1192         AssertPos(pos, Type::EqualIgnoringConst(type, AtomicType::UniformBool) ||
1193                            Type::EqualIgnoringConst(type, AtomicType::VaryingBool));
1194         bool v[ISPC_MAX_NVEC];
1195         int count = constExpr->GetValues(v);
1196         for (int i = 0; i < count; ++i)
1197             v[i] = !v[i];
1198         return new ConstExpr(type, v, pos);
1199     }
1200     default:
1201         FATAL("unexpected op in UnaryExpr::Optimize()");
1202         return NULL;
1203     }
1204 }
1205 
TypeCheck()1206 Expr *UnaryExpr::TypeCheck() {
1207     const Type *type;
1208     if (expr == NULL || (type = expr->GetType()) == NULL)
1209         // something went wrong in type checking...
1210         return NULL;
1211 
1212     if (type->IsSOAType()) {
1213         Error(pos, "Can't apply unary operator to SOA type \"%s\".", type->GetString().c_str());
1214         return NULL;
1215     }
1216 
1217     if (op == PreInc || op == PreDec || op == PostInc || op == PostDec) {
1218         if (type->IsConstType()) {
1219             Error(pos,
1220                   "Can't assign to type \"%s\" on left-hand side of "
1221                   "expression.",
1222                   type->GetString().c_str());
1223             return NULL;
1224         }
1225 
1226         if (type->IsNumericType())
1227             return this;
1228 
1229         const PointerType *pt = CastType<PointerType>(type);
1230         if (pt == NULL) {
1231             Error(expr->pos,
1232                   "Can only pre/post increment numeric and "
1233                   "pointer types, not \"%s\".",
1234                   type->GetString().c_str());
1235             return NULL;
1236         }
1237 
1238         if (PointerType::IsVoidPointer(type)) {
1239             Error(expr->pos, "Illegal to pre/post increment \"%s\" type.", type->GetString().c_str());
1240             return NULL;
1241         }
1242         if (CastType<UndefinedStructType>(pt->GetBaseType())) {
1243             Error(expr->pos,
1244                   "Illegal to pre/post increment pointer to "
1245                   "undefined struct type \"%s\".",
1246                   type->GetString().c_str());
1247             return NULL;
1248         }
1249 
1250         return this;
1251     }
1252 
1253     // don't do this for pre/post increment/decrement
1254     if (CastType<ReferenceType>(type)) {
1255         expr = new RefDerefExpr(expr, pos);
1256         type = expr->GetType();
1257     }
1258 
1259     if (op == Negate) {
1260         if (!type->IsNumericType()) {
1261             Error(expr->pos, "Negate not allowed for non-numeric type \"%s\".", type->GetString().c_str());
1262             return NULL;
1263         }
1264     } else if (op == LogicalNot) {
1265         const Type *boolType = lMatchingBoolType(type);
1266         expr = TypeConvertExpr(expr, boolType, "logical not");
1267         if (expr == NULL)
1268             return NULL;
1269     } else if (op == BitNot) {
1270         if (!type->IsIntType()) {
1271             Error(expr->pos,
1272                   "~ operator can only be used with integer types, "
1273                   "not \"%s\".",
1274                   type->GetString().c_str());
1275             return NULL;
1276         }
1277     }
1278     return this;
1279 }
1280 
EstimateCost() const1281 int UnaryExpr::EstimateCost() const {
1282     if (llvm::dyn_cast<ConstExpr>(expr) != NULL)
1283         return 0;
1284 
1285     return COST_SIMPLE_ARITH_LOGIC_OP;
1286 }
1287 
Print() const1288 void UnaryExpr::Print() const {
1289     if (!expr || !GetType())
1290         return;
1291 
1292     printf("[ %s ] (", GetType()->GetString().c_str());
1293     if (op == PreInc)
1294         printf("++");
1295     if (op == PreDec)
1296         printf("--");
1297     if (op == Negate)
1298         printf("-");
1299     if (op == LogicalNot)
1300         printf("!");
1301     if (op == BitNot)
1302         printf("~");
1303     printf("(");
1304     expr->Print();
1305     printf(")");
1306     if (op == PostInc)
1307         printf("++");
1308     if (op == PostDec)
1309         printf("--");
1310     printf(")");
1311     pos.Print();
1312 }
1313 
1314 ///////////////////////////////////////////////////////////////////////////
1315 // BinaryExpr
1316 
lOpString(BinaryExpr::Op op)1317 static const char *lOpString(BinaryExpr::Op op) {
1318     switch (op) {
1319     case BinaryExpr::Add:
1320         return "+";
1321     case BinaryExpr::Sub:
1322         return "-";
1323     case BinaryExpr::Mul:
1324         return "*";
1325     case BinaryExpr::Div:
1326         return "/";
1327     case BinaryExpr::Mod:
1328         return "%";
1329     case BinaryExpr::Shl:
1330         return "<<";
1331     case BinaryExpr::Shr:
1332         return ">>";
1333     case BinaryExpr::Lt:
1334         return "<";
1335     case BinaryExpr::Gt:
1336         return ">";
1337     case BinaryExpr::Le:
1338         return "<=";
1339     case BinaryExpr::Ge:
1340         return ">=";
1341     case BinaryExpr::Equal:
1342         return "==";
1343     case BinaryExpr::NotEqual:
1344         return "!=";
1345     case BinaryExpr::BitAnd:
1346         return "&";
1347     case BinaryExpr::BitXor:
1348         return "^";
1349     case BinaryExpr::BitOr:
1350         return "|";
1351     case BinaryExpr::LogicalAnd:
1352         return "&&";
1353     case BinaryExpr::LogicalOr:
1354         return "||";
1355     case BinaryExpr::Comma:
1356         return ",";
1357     default:
1358         FATAL("unimplemented case in lOpString()");
1359         return "";
1360     }
1361 }
1362 
1363 /** Utility routine to emit the binary bitwise operator corresponding to
1364     the given BinaryExpr::Op.
1365 */
lEmitBinaryBitOp(BinaryExpr::Op op,llvm::Value * arg0Val,llvm::Value * arg1Val,bool isUnsigned,FunctionEmitContext * ctx)1366 static llvm::Value *lEmitBinaryBitOp(BinaryExpr::Op op, llvm::Value *arg0Val, llvm::Value *arg1Val, bool isUnsigned,
1367                                      FunctionEmitContext *ctx) {
1368     llvm::Instruction::BinaryOps inst;
1369     switch (op) {
1370     case BinaryExpr::Shl:
1371         inst = llvm::Instruction::Shl;
1372         break;
1373     case BinaryExpr::Shr:
1374         if (isUnsigned)
1375             inst = llvm::Instruction::LShr;
1376         else
1377             inst = llvm::Instruction::AShr;
1378         break;
1379     case BinaryExpr::BitAnd:
1380         inst = llvm::Instruction::And;
1381         break;
1382     case BinaryExpr::BitXor:
1383         inst = llvm::Instruction::Xor;
1384         break;
1385     case BinaryExpr::BitOr:
1386         inst = llvm::Instruction::Or;
1387         break;
1388     default:
1389         FATAL("logic error in lEmitBinaryBitOp()");
1390         return NULL;
1391     }
1392 
1393     return ctx->BinaryOperator(inst, arg0Val, arg1Val, "bitop");
1394 }
1395 
lEmitBinaryPointerArith(BinaryExpr::Op op,llvm::Value * value0,llvm::Value * value1,const Type * type0,const Type * type1,FunctionEmitContext * ctx,SourcePos pos)1396 static llvm::Value *lEmitBinaryPointerArith(BinaryExpr::Op op, llvm::Value *value0, llvm::Value *value1,
1397                                             const Type *type0, const Type *type1, FunctionEmitContext *ctx,
1398                                             SourcePos pos) {
1399     const PointerType *ptrType = CastType<PointerType>(type0);
1400     AssertPos(pos, ptrType != NULL);
1401     switch (op) {
1402     case BinaryExpr::Add:
1403         // ptr + integer
1404         return ctx->GetElementPtrInst(value0, value1, ptrType, "ptrmath");
1405         break;
1406     case BinaryExpr::Sub: {
1407         if (CastType<PointerType>(type1) != NULL) {
1408             AssertPos(pos, Type::EqualIgnoringConst(type0, type1));
1409 
1410             if (ptrType->IsSlice()) {
1411                 llvm::Value *p0 = ctx->ExtractInst(value0, 0);
1412                 llvm::Value *p1 = ctx->ExtractInst(value1, 0);
1413                 const Type *majorType = ptrType->GetAsNonSlice();
1414                 llvm::Value *majorDelta = lEmitBinaryPointerArith(op, p0, p1, majorType, majorType, ctx, pos);
1415 
1416                 int soaWidth = ptrType->GetBaseType()->GetSOAWidth();
1417                 AssertPos(pos, soaWidth > 0);
1418                 llvm::Value *soaScale = LLVMIntAsType(soaWidth, majorDelta->getType());
1419 
1420                 llvm::Value *majorScale =
1421                     ctx->BinaryOperator(llvm::Instruction::Mul, majorDelta, soaScale, "major_soa_scaled");
1422 
1423                 llvm::Value *m0 = ctx->ExtractInst(value0, 1);
1424                 llvm::Value *m1 = ctx->ExtractInst(value1, 1);
1425                 llvm::Value *minorDelta = ctx->BinaryOperator(llvm::Instruction::Sub, m0, m1, "minor_soa_delta");
1426 
1427                 ctx->MatchIntegerTypes(&majorScale, &minorDelta);
1428                 return ctx->BinaryOperator(llvm::Instruction::Add, majorScale, minorDelta, "soa_ptrdiff");
1429             }
1430 
1431             // ptr - ptr
1432             if (ptrType->IsUniformType()) {
1433                 value0 = ctx->PtrToIntInst(value0);
1434                 value1 = ctx->PtrToIntInst(value1);
1435             }
1436 
1437             // Compute the difference in bytes
1438             llvm::Value *delta = ctx->BinaryOperator(llvm::Instruction::Sub, value0, value1, "ptr_diff");
1439 
1440             // Now divide by the size of the type that the pointer
1441             // points to in order to return the difference in elements.
1442             llvm::Type *llvmElementType = ptrType->GetBaseType()->LLVMType(g->ctx);
1443             llvm::Value *size = g->target->SizeOf(llvmElementType, ctx->GetCurrentBasicBlock());
1444             if (ptrType->IsVaryingType())
1445                 size = ctx->SmearUniform(size);
1446 
1447             if (g->target->is32Bit() == false && g->opt.force32BitAddressing == true) {
1448                 // If we're doing 32-bit addressing math on a 64-bit
1449                 // target, then trunc the delta down to a 32-bit value.
1450                 // (Thus also matching what will be a 32-bit value
1451                 // returned from SizeOf above.)
1452                 if (ptrType->IsUniformType())
1453                     delta = ctx->TruncInst(delta, LLVMTypes::Int32Type, "trunc_ptr_delta");
1454                 else
1455                     delta = ctx->TruncInst(delta, LLVMTypes::Int32VectorType, "trunc_ptr_delta");
1456             }
1457 
1458             // And now do the actual division
1459             return ctx->BinaryOperator(llvm::Instruction::SDiv, delta, size, "element_diff");
1460         } else {
1461             // ptr - integer
1462             llvm::Value *zero = lLLVMConstantValue(type1, g->ctx, 0.);
1463             llvm::Value *negOffset = ctx->BinaryOperator(llvm::Instruction::Sub, zero, value1, "negate");
1464             // Do a GEP as ptr + -integer
1465             return ctx->GetElementPtrInst(value0, negOffset, ptrType, "ptrmath");
1466         }
1467     }
1468     default:
1469         FATAL("Logic error in lEmitBinaryArith() for pointer type case");
1470         return NULL;
1471     }
1472 }
1473 
1474 /** Utility routine to emit binary arithmetic operator based on the given
1475     BinaryExpr::Op.
1476 */
lEmitBinaryArith(BinaryExpr::Op op,llvm::Value * value0,llvm::Value * value1,const Type * type0,const Type * type1,FunctionEmitContext * ctx,SourcePos pos)1477 static llvm::Value *lEmitBinaryArith(BinaryExpr::Op op, llvm::Value *value0, llvm::Value *value1, const Type *type0,
1478                                      const Type *type1, FunctionEmitContext *ctx, SourcePos pos) {
1479     const PointerType *ptrType = CastType<PointerType>(type0);
1480 
1481     if (ptrType != NULL)
1482         return lEmitBinaryPointerArith(op, value0, value1, type0, type1, ctx, pos);
1483     else {
1484         AssertPos(pos, Type::EqualIgnoringConst(type0, type1));
1485 
1486         llvm::Instruction::BinaryOps inst;
1487         bool isFloatOp = type0->IsFloatType();
1488         bool isUnsignedOp = type0->IsUnsignedType();
1489 
1490         const char *opName = NULL;
1491         switch (op) {
1492         case BinaryExpr::Add:
1493             opName = "add";
1494             inst = isFloatOp ? llvm::Instruction::FAdd : llvm::Instruction::Add;
1495             break;
1496         case BinaryExpr::Sub:
1497             opName = "sub";
1498             inst = isFloatOp ? llvm::Instruction::FSub : llvm::Instruction::Sub;
1499             break;
1500         case BinaryExpr::Mul:
1501             opName = "mul";
1502             inst = isFloatOp ? llvm::Instruction::FMul : llvm::Instruction::Mul;
1503             break;
1504         case BinaryExpr::Div:
1505             opName = "div";
1506             if (type0->IsVaryingType() && !isFloatOp)
1507                 PerformanceWarning(pos, "Division with varying integer types is "
1508                                         "very inefficient.");
1509             inst = isFloatOp ? llvm::Instruction::FDiv
1510                              : (isUnsignedOp ? llvm::Instruction::UDiv : llvm::Instruction::SDiv);
1511             break;
1512         case BinaryExpr::Mod:
1513             opName = "mod";
1514             if (type0->IsVaryingType() && !isFloatOp)
1515                 PerformanceWarning(pos, "Modulus operator with varying types is "
1516                                         "very inefficient.");
1517             inst = isFloatOp ? llvm::Instruction::FRem
1518                              : (isUnsignedOp ? llvm::Instruction::URem : llvm::Instruction::SRem);
1519             break;
1520         default:
1521             FATAL("Invalid op type passed to lEmitBinaryArith()");
1522             return NULL;
1523         }
1524 
1525         return ctx->BinaryOperator(inst, value0, value1,
1526                                    (((llvm::Twine(opName) + "_") + value0->getName()) + "_") + value1->getName());
1527     }
1528 }
1529 
1530 /** Utility routine to emit a binary comparison operator based on the given
1531     BinaryExpr::Op.
1532  */
lEmitBinaryCmp(BinaryExpr::Op op,llvm::Value * e0Val,llvm::Value * e1Val,const Type * type,FunctionEmitContext * ctx,SourcePos pos)1533 static llvm::Value *lEmitBinaryCmp(BinaryExpr::Op op, llvm::Value *e0Val, llvm::Value *e1Val, const Type *type,
1534                                    FunctionEmitContext *ctx, SourcePos pos) {
1535     bool isFloatOp = type->IsFloatType();
1536     bool isUnsignedOp = type->IsUnsignedType();
1537 
1538     llvm::CmpInst::Predicate pred;
1539     const char *opName = NULL;
1540     switch (op) {
1541     case BinaryExpr::Lt:
1542         opName = "less";
1543         pred = isFloatOp ? llvm::CmpInst::FCMP_OLT : (isUnsignedOp ? llvm::CmpInst::ICMP_ULT : llvm::CmpInst::ICMP_SLT);
1544         break;
1545     case BinaryExpr::Gt:
1546         opName = "greater";
1547         pred = isFloatOp ? llvm::CmpInst::FCMP_OGT : (isUnsignedOp ? llvm::CmpInst::ICMP_UGT : llvm::CmpInst::ICMP_SGT);
1548         break;
1549     case BinaryExpr::Le:
1550         opName = "lessequal";
1551         pred = isFloatOp ? llvm::CmpInst::FCMP_OLE : (isUnsignedOp ? llvm::CmpInst::ICMP_ULE : llvm::CmpInst::ICMP_SLE);
1552         break;
1553     case BinaryExpr::Ge:
1554         opName = "greaterequal";
1555         pred = isFloatOp ? llvm::CmpInst::FCMP_OGE : (isUnsignedOp ? llvm::CmpInst::ICMP_UGE : llvm::CmpInst::ICMP_SGE);
1556         break;
1557     case BinaryExpr::Equal:
1558         opName = "equal";
1559         pred = isFloatOp ? llvm::CmpInst::FCMP_OEQ : llvm::CmpInst::ICMP_EQ;
1560         break;
1561     case BinaryExpr::NotEqual:
1562         opName = "notequal";
1563         pred = isFloatOp ? llvm::CmpInst::FCMP_UNE : llvm::CmpInst::ICMP_NE;
1564         break;
1565     default:
1566         FATAL("error in lEmitBinaryCmp()");
1567         return NULL;
1568     }
1569 
1570     llvm::Value *cmp = ctx->CmpInst(isFloatOp ? llvm::Instruction::FCmp : llvm::Instruction::ICmp, pred, e0Val, e1Val,
1571                                     (((llvm::Twine(opName) + "_") + e0Val->getName()) + "_") + e1Val->getName());
1572     // This is a little ugly: CmpInst returns i1 values, but we use vectors
1573     // of i32s for varying bool values; type convert the result here if
1574     // needed.
1575     if (type->IsVaryingType())
1576         cmp = ctx->I1VecToBoolVec(cmp);
1577 
1578     return cmp;
1579 }
1580 
BinaryExpr(Op o,Expr * a,Expr * b,SourcePos p)1581 BinaryExpr::BinaryExpr(Op o, Expr *a, Expr *b, SourcePos p) : Expr(p, BinaryExprID), op(o) {
1582     arg0 = a;
1583     arg1 = b;
1584 }
1585 
lCreateBinaryOperatorCall(const BinaryExpr::Op bop,Expr * a0,Expr * a1,Expr * & op,const SourcePos & sp)1586 bool lCreateBinaryOperatorCall(const BinaryExpr::Op bop, Expr *a0, Expr *a1, Expr *&op, const SourcePos &sp) {
1587     bool abort = false;
1588     if ((a0 == NULL) || (a1 == NULL)) {
1589         return abort;
1590     }
1591     Expr *arg0 = a0;
1592     Expr *arg1 = a1;
1593     const Type *type0 = arg0->GetType();
1594     const Type *type1 = arg1->GetType();
1595 
1596     // If either operand is a reference, dereference it before we move
1597     // forward
1598     if (CastType<ReferenceType>(type0) != NULL) {
1599         arg0 = new RefDerefExpr(arg0, arg0->pos);
1600         type0 = arg0->GetType();
1601     }
1602     if (CastType<ReferenceType>(type1) != NULL) {
1603         arg1 = new RefDerefExpr(arg1, arg1->pos);
1604         type1 = arg1->GetType();
1605     }
1606     if ((type0 == NULL) || (type1 == NULL)) {
1607         return abort;
1608     }
1609     if (CastType<StructType>(type0) != NULL || CastType<StructType>(type1) != NULL) {
1610         std::string opName = std::string("operator") + lOpString(bop);
1611         std::vector<Symbol *> funs;
1612         m->symbolTable->LookupFunction(opName.c_str(), &funs);
1613         if (funs.size() == 0) {
1614             Error(sp, "operator %s(%s, %s) is not defined.", opName.c_str(), (type0->GetString()).c_str(),
1615                   (type1->GetString()).c_str());
1616             abort = true;
1617             return abort;
1618         }
1619         Expr *func = new FunctionSymbolExpr(opName.c_str(), funs, sp);
1620         ExprList *args = new ExprList(sp);
1621         args->exprs.push_back(arg0);
1622         args->exprs.push_back(arg1);
1623         op = new FunctionCallExpr(func, args, sp);
1624         return abort;
1625     }
1626     return abort;
1627 }
1628 
MakeBinaryExpr(BinaryExpr::Op o,Expr * a,Expr * b,SourcePos p)1629 Expr *ispc::MakeBinaryExpr(BinaryExpr::Op o, Expr *a, Expr *b, SourcePos p) {
1630     Expr *op = NULL;
1631     bool abort = lCreateBinaryOperatorCall(o, a, b, op, p);
1632     if (op != NULL) {
1633         return op;
1634     }
1635 
1636     // lCreateBinaryOperatorCall can return NULL for 2 cases:
1637     // 1. When there is an error.
1638     // 2. We have to create a new BinaryExpr.
1639     if (abort) {
1640         AssertPos(p, m->errorCount > 0);
1641         return NULL;
1642     }
1643 
1644     op = new BinaryExpr(o, a, b, p);
1645     return op;
1646 }
1647 
/** Emit code for a && or || logical operator.  In particular, the code
    here handles "short-circuit" evaluation, where the second expression
    isn't evaluated if the value of the first one determines the value of
    the result.

    @param op    BinaryExpr::LogicalAnd or BinaryExpr::LogicalOr.
    @param arg0  Left-hand operand expression.
    @param arg1  Right-hand operand expression.
    @param ctx   Emission context used to create instructions and blocks.
    @param pos   Source position used for assertions and type errors.
    @return The value of the logical expression, or NULL when an operand
            type, the result type, or an operand value is unavailable (in
            which case m->errorCount is asserted to be non-zero).
*/
llvm::Value *lEmitLogicalOp(BinaryExpr::Op op, Expr *arg0, Expr *arg1, FunctionEmitContext *ctx, SourcePos pos) {

    const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
    if (type0 == NULL || type1 == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    // There is overhead (branches, etc.), to short-circuiting, so if the
    // right side of the expression is a) relatively simple, and b) can be
    // safely executed with an all-off execution mask, then we just
    // evaluate both sides and then the logical operator in that case.
    int threshold =
        g->target->isGenXTarget() ? PREDICATE_SAFE_SHORT_CIRC_GENX_STATEMENT_COST : PREDICATE_SAFE_IF_STATEMENT_COST;
    bool shortCircuit = (EstimateCost(arg1) > threshold || SafeToRunWithMaskAllOff(arg1) == false);

    // Skip short-circuiting for VectorTypes as well.
    if ((shortCircuit == false) || CastType<VectorType>(type0) != NULL || CastType<VectorType>(type1) != NULL) {
        // If one of the operands is uniform but the other is varying,
        // promote the uniform one to varying
        if (type0->IsUniformType() && type1->IsVaryingType()) {
            arg0 = TypeConvertExpr(arg0, AtomicType::VaryingBool, lOpString(op));
            AssertPos(pos, arg0 != NULL);
        }
        if (type1->IsUniformType() && type0->IsVaryingType()) {
            arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, lOpString(op));
            AssertPos(pos, arg1 != NULL);
        }

        // Both operands are evaluated unconditionally on this path.
        llvm::Value *value0 = arg0->GetValue(ctx);
        llvm::Value *value1 = arg1->GetValue(ctx);
        if (value0 == NULL || value1 == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }

        if (op == BinaryExpr::LogicalAnd)
            return ctx->BinaryOperator(llvm::Instruction::And, value0, value1, "logical_and");
        else {
            AssertPos(pos, op == BinaryExpr::LogicalOr);
            return ctx->BinaryOperator(llvm::Instruction::Or, value0, value1, "logical_or");
        }
    }

    // Allocate temporary storage for the return value
    const Type *retType = Type::MoreGeneralType(type0, type1, pos, lOpString(op));
    if (retType == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }
    llvm::Value *retPtr = ctx->AllocaInst(retType, "logical_op_mem");
    // Three blocks: one for when arg1 is skipped, one where arg1 is
    // evaluated, and the join point where the stored result is loaded.
    llvm::BasicBlock *bbSkipEvalValue1 = ctx->CreateBasicBlock("skip_eval_1", ctx->GetCurrentBasicBlock());
    llvm::BasicBlock *bbEvalValue1 = ctx->CreateBasicBlock("eval_1", bbSkipEvalValue1);
    llvm::BasicBlock *bbLogicalDone = ctx->CreateBasicBlock("logical_op_done", bbEvalValue1);

    // Evaluate the first operand
    llvm::Value *value0 = arg0->GetValue(ctx);
    if (value0 == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    if (type0->IsUniformType()) {
        // Check to see if the value of the first operand is true or false
        llvm::Value *value0True = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, value0, LLVMTrue);

        if (op == BinaryExpr::LogicalOr) {
            // For ||, if value0 is true, then we skip evaluating value1
            // entirely.
            ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, value0True);

            // If value0 is true, the complete result is true (either
            // uniform or varying)
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            llvm::Value *trueValue = retType->IsUniformType() ? LLVMTrue : LLVMMaskAllOn;
            ctx->StoreInst(trueValue, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        } else {
            AssertPos(pos, op == BinaryExpr::LogicalAnd);

            // Conversely, for &&, if value0 is false, we skip evaluating
            // value1.
            ctx->BranchInst(bbEvalValue1, bbSkipEvalValue1, value0True);

            // In this case, the complete result is false (again, either a
            // uniform or varying false).
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            llvm::Value *falseValue = retType->IsUniformType() ? LLVMFalse : LLVMMaskAllOff;
            ctx->StoreInst(falseValue, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        }

        // Both || and && are in the same situation if the first operand's
        // value didn't resolve the final result: they need to evaluate the
        // value of the second operand, which in turn gives the value for
        // the full expression.
        ctx->SetCurrentBasicBlock(bbEvalValue1);
        // A uniform second operand is widened when the overall result is
        // varying, so that the value stored below matches retType.
        if (type1->IsUniformType() && retType->IsVaryingType()) {
            arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, "logical op");
            AssertPos(pos, arg1 != NULL);
        }

        llvm::Value *value1 = arg1->GetValue(ctx);
        if (value1 == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }
        ctx->StoreInst(value1, retPtr, arg1->GetType(), retType->IsUniformType());
        ctx->BranchInst(bbLogicalDone);

        // In all cases, we end up at the bbLogicalDone basic block;
        // loading the value stored in retPtr in turn gives the overall
        // result.
        ctx->SetCurrentBasicBlock(bbLogicalDone);
        return ctx->LoadInst(retPtr, retType);
    } else {
        // Otherwise, the first operand is varying...  Save the current
        // value of the mask so that we can restore it at the end.
        // NOTE(review): the early "return NULL" paths below leave after the
        // internal mask has been modified and without restoring oldMask;
        // presumably acceptable because they are error paths -- confirm.
        llvm::Value *oldMask = ctx->GetInternalMask();
        llvm::Value *oldFullMask = ctx->GetFullMask();

        // Convert the second operand to be varying as well, so that we can
        // perform logical vector ops with its value.
        if (type1->IsUniformType()) {
            arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, "logical op");
            AssertPos(pos, arg1 != NULL);
            type1 = arg1->GetType();
        }

        if (op == BinaryExpr::LogicalOr) {
            // See if value0 is true for all currently executing
            // lanes--i.e. if (value0 & mask) == mask.  If so, we don't
            // need to evaluate the second operand of the expression.
            llvm::Value *value0AndMask = ctx->BinaryOperator(llvm::Instruction::And, value0, oldFullMask, "op&mask");
            llvm::Value *equalsMask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, value0AndMask,
                                                   oldFullMask, "value0&mask==mask");
            equalsMask = ctx->I1VecToBoolVec(equalsMask);
            if (!ctx->emitGenXHardwareMask()) {
                llvm::Value *allMatch = ctx->All(equalsMask);
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, allMatch);
            } else {
                // If uniform CF is emulated, pass vector value to BranchInst
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, equalsMask);
            }

            // value0 is true for all running lanes, so it can be used for
            // the final result
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            ctx->StoreInst(value0, retPtr, arg0->GetType(), retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);

            // Otherwise, we need to evaluate arg1. However, first we need
            // to set the execution mask to be (oldMask & ~a); in other
            // words, only execute the instances where value0 is false.
            // For the instances where value0 was true, we need to inhibit
            // execution.
            ctx->SetCurrentBasicBlock(bbEvalValue1);
            ctx->SetInternalMaskAndNot(oldMask, value0);

            llvm::Value *value1 = arg1->GetValue(ctx);
            if (value1 == NULL) {
                AssertPos(pos, m->errorCount > 0);
                return NULL;
            }

            // We need to compute the result carefully, since vector
            // elements that were computed when the corresponding lane was
            // disabled have undefined values:
            // result = (value0 & old_mask) | (value1 & current_mask)
            llvm::Value *value1AndMask =
                ctx->BinaryOperator(llvm::Instruction::And, value1, ctx->GetInternalMask(), "op&mask");
            llvm::Value *result = ctx->BinaryOperator(llvm::Instruction::Or, value0AndMask, value1AndMask, "or_result");
            ctx->StoreInst(result, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        } else {
            AssertPos(pos, op == BinaryExpr::LogicalAnd);

            // If value0 is false for all currently running lanes, the
            // overall result must be false: this corresponds to checking
            // if (mask & ~value0) == mask.
            llvm::Value *notValue0 = ctx->NotOperator(value0, "not_value0");
            llvm::Value *notValue0AndMask =
                ctx->BinaryOperator(llvm::Instruction::And, notValue0, oldFullMask, "not_value0&mask");
            llvm::Value *equalsMask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, notValue0AndMask,
                                                   oldFullMask, "not_value0&mask==mask");
            equalsMask = ctx->I1VecToBoolVec(equalsMask);
            if (!ctx->emitGenXHardwareMask()) {
                llvm::Value *allMatch = ctx->All(equalsMask);
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, allMatch);
            } else {
                // If uniform CF is emulated, pass vector value to BranchInst
                ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, equalsMask);
            }

            // value0 was false for all running lanes, so use its value as
            // the overall result.
            ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
            ctx->StoreInst(value0, retPtr, arg0->GetType(), retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);

            // Otherwise we need to evaluate value1, but again with the
            // mask set to only be on for the lanes where value0 was true.
            // For the lanes where value0 was false, execution needs to be
            // disabled: mask = (mask & value0).
            ctx->SetCurrentBasicBlock(bbEvalValue1);
            ctx->SetInternalMaskAnd(oldMask, value0);

            llvm::Value *value1 = arg1->GetValue(ctx);
            if (value1 == NULL) {
                AssertPos(pos, m->errorCount > 0);
                return NULL;
            }

            // And as in the || case, we compute the overall result by
            // masking each operand with its mask before we AND them together:
            // result = (value0 & old_mask) & (value1 & current_mask)
            // NOTE(review): the final instruction below is an And although
            // its value name is "or_result"; the name is only a label.
            llvm::Value *value0AndMask = ctx->BinaryOperator(llvm::Instruction::And, value0, oldFullMask, "op&mask");
            llvm::Value *value1AndMask =
                ctx->BinaryOperator(llvm::Instruction::And, value1, ctx->GetInternalMask(), "value1&mask");
            llvm::Value *result =
                ctx->BinaryOperator(llvm::Instruction::And, value0AndMask, value1AndMask, "or_result");
            ctx->StoreInst(result, retPtr, retType, retType->IsUniformType());
            ctx->BranchInst(bbLogicalDone);
        }

        // And finally we always end up in bbLogicalDone, where we restore
        // the old mask and return the computed result
        ctx->SetCurrentBasicBlock(bbLogicalDone);
        ctx->SetInternalMask(oldMask);
        return ctx->LoadInst(retPtr, retType);
    }
}
1885 
1886 /* Returns true if shifting right by the given amount will lead to
1887    inefficient code.  (Assumes x86 target.  May also warn inaccurately if
1888    later optimization simplify the shift amount more than we are able to
1889    see at this point.) */
lIsDifficultShiftAmount(Expr * expr)1890 static bool lIsDifficultShiftAmount(Expr *expr) {
1891     // Uniform shifts (of uniform values) are no problem.
1892     if (expr->GetType()->IsVaryingType() == false)
1893         return false;
1894 
1895     ConstExpr *ce = llvm::dyn_cast<ConstExpr>(expr);
1896     if (ce) {
1897         // If the shift is by a constant amount, *and* it's the same amount
1898         // in all vector lanes, we're in good shape.
1899         uint32_t amount[ISPC_MAX_NVEC];
1900         int count = ce->GetValues(amount);
1901         for (int i = 1; i < count; ++i)
1902             if (amount[i] != amount[0])
1903                 return true;
1904         return false;
1905     }
1906 
1907     TypeCastExpr *tce = llvm::dyn_cast<TypeCastExpr>(expr);
1908     if (tce && tce->expr) {
1909         // Finally, if the shift amount is given by a uniform value that's
1910         // been smeared out into a varying, we have the same shift for all
1911         // lanes and are also in good shape.
1912         return (tce->expr->GetType()->IsUniformType() == false);
1913     }
1914 
1915     return true;
1916 }
1917 
HasAmbiguousVariability(std::vector<const Expr * > & warn) const1918 bool BinaryExpr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
1919     bool isArg0Amb = false;
1920     bool isArg1Amb = false;
1921     if (arg0 != NULL) {
1922         const Type *type0 = arg0->GetType();
1923         if (arg0->HasAmbiguousVariability(warn)) {
1924             isArg0Amb = true;
1925         } else if ((type0 != NULL) && (type0->IsVaryingType())) {
1926             // If either arg is varying, then the expression is un-ambiguously varying.
1927             return false;
1928         }
1929     }
1930     if (arg1 != NULL) {
1931         const Type *type1 = arg1->GetType();
1932         if (arg1->HasAmbiguousVariability(warn)) {
1933             isArg1Amb = true;
1934         } else if ((type1 != NULL) && (type1->IsVaryingType())) {
1935             // If either arg is varying, then the expression is un-ambiguously varying.
1936             return false;
1937         }
1938     }
1939     if (isArg0Amb || isArg1Amb) {
1940         return true;
1941     }
1942 
1943     return false;
1944 }
1945 
GetValue(FunctionEmitContext * ctx) const1946 llvm::Value *BinaryExpr::GetValue(FunctionEmitContext *ctx) const {
1947     if (!arg0 || !arg1) {
1948         AssertPos(pos, m->errorCount > 0);
1949         return NULL;
1950     }
1951 
1952     // Handle these specially, since we want to short-circuit their evaluation...
1953     if (op == LogicalAnd || op == LogicalOr)
1954         return lEmitLogicalOp(op, arg0, arg1, ctx, pos);
1955 
1956     llvm::Value *value0 = arg0->GetValue(ctx);
1957     llvm::Value *value1 = arg1->GetValue(ctx);
1958     if (value0 == NULL || value1 == NULL) {
1959         AssertPos(pos, m->errorCount > 0);
1960         return NULL;
1961     }
1962 
1963     ctx->SetDebugPos(pos);
1964 
1965     switch (op) {
1966     case Add:
1967     case Sub:
1968     case Mul:
1969     case Div:
1970     case Mod:
1971         return lEmitBinaryArith(op, value0, value1, arg0->GetType(), arg1->GetType(), ctx, pos);
1972     case Lt:
1973     case Gt:
1974     case Le:
1975     case Ge:
1976     case Equal:
1977     case NotEqual:
1978         return lEmitBinaryCmp(op, value0, value1, arg0->GetType(), ctx, pos);
1979     case Shl:
1980     case Shr:
1981     case BitAnd:
1982     case BitXor:
1983     case BitOr: {
1984         if (op == Shr && lIsDifficultShiftAmount(arg1))
1985             PerformanceWarning(pos, "Shift right is inefficient for "
1986                                     "varying shift amounts.");
1987         return lEmitBinaryBitOp(op, value0, value1, arg0->GetType()->IsUnsignedType(), ctx);
1988     }
1989     case Comma:
1990         return value1;
1991     default:
1992         FATAL("logic error");
1993         return NULL;
1994     }
1995 }
1996 
GetType() const1997 const Type *BinaryExpr::GetType() const {
1998     if (arg0 == NULL || arg1 == NULL)
1999         return NULL;
2000 
2001     const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
2002     if (type0 == NULL || type1 == NULL)
2003         return NULL;
2004 
2005     // If this hits, it means that our TypeCheck() method hasn't been
2006     // called before GetType() was called; adding two pointers is illegal
2007     // and will fail type checking and (int + ptr) should be canonicalized
2008     // into (ptr + int) by type checking.
2009     if (op == Add)
2010         AssertPos(pos, CastType<PointerType>(type1) == NULL);
2011 
2012     if (op == Comma)
2013         return arg1->GetType();
2014 
2015     if (CastType<PointerType>(type0) != NULL) {
2016         if (op == Add)
2017             // ptr + int -> ptr
2018             return type0;
2019         else if (op == Sub) {
2020             if (CastType<PointerType>(type1) != NULL) {
2021                 // ptr - ptr -> ~ptrdiff_t
2022                 const Type *diffType = (g->target->is32Bit() || g->opt.force32BitAddressing) ? AtomicType::UniformInt32
2023                                                                                              : AtomicType::UniformInt64;
2024                 if (type0->IsVaryingType() || type1->IsVaryingType())
2025                     diffType = diffType->GetAsVaryingType();
2026                 return diffType;
2027             } else
2028                 // ptr - int -> ptr
2029                 return type0;
2030         }
2031 
2032         // otherwise fall through for these...
2033         AssertPos(pos, op == Lt || op == Gt || op == Le || op == Ge || op == Equal || op == NotEqual);
2034     }
2035 
2036     const Type *exprType = Type::MoreGeneralType(type0, type1, pos, lOpString(op));
2037     // I don't think that MoreGeneralType should be able to fail after the
2038     // checks done in BinaryExpr::TypeCheck().
2039     AssertPos(pos, exprType != NULL);
2040 
2041     switch (op) {
2042     case Add:
2043     case Sub:
2044     case Mul:
2045     case Div:
2046     case Mod:
2047         return exprType;
2048     case Lt:
2049     case Gt:
2050     case Le:
2051     case Ge:
2052     case Equal:
2053     case NotEqual:
2054     case LogicalAnd:
2055     case LogicalOr:
2056         return lMatchingBoolType(exprType);
2057     case Shl:
2058     case Shr:
2059         return type1->IsVaryingType() ? type0->GetAsVaryingType() : type0;
2060     case BitAnd:
2061     case BitXor:
2062     case BitOr:
2063         return exprType;
2064     case Comma:
2065         // handled above, so fall through here just in case
2066     default:
2067         FATAL("logic error in BinaryExpr::GetType()");
2068         return NULL;
2069     }
2070 }
2071 
// FOLD_OP(O, E): expands to a complete "case O:" arm of a switch statement
// that applies the binary operator E element-wise, storing
// (v0[i] E v1[i]) into result[i] for every i in [0, count).
// The expansion refers to the identifiers "count", "result", "v0" and "v1"
// by name, so they must be visible at the point of use.  The trailing
// "break" has no semicolon; the user of the macro supplies it.
#define FOLD_OP(O, E)                                                                                                  \
    case O:                                                                                                            \
        for (int i = 0; i < count; ++i)                                                                                \
            result[i] = (v0[i] E v1[i]);                                                                               \
        break
2077 
// FOLD_OP_REF(O, E, TRef): like FOLD_OP, but additionally recomputes each
// element with both operands cast to the reference type TRef.  If the
// value stored in result[i] differs from that reference value, a warning
// is issued saying that the expression's type can't represent the value.
// Besides "count", "result", "v0" and "v1", the expansion also refers to
// "carg0" (for the type name in the message) and "pos" (for the warning
// location) by name.  As with FOLD_OP, the trailing "break" needs a
// semicolon from the user of the macro.
#define FOLD_OP_REF(O, E, TRef)                                                                                        \
    case O:                                                                                                            \
        for (int i = 0; i < count; ++i) {                                                                              \
            result[i] = (v0[i] E v1[i]);                                                                               \
            TRef r = (TRef)v0[i] E(TRef) v1[i];                                                                        \
            if (result[i] != r)                                                                                        \
                Warning(pos, "Binary expression with type \"%s\" can't represent value.",                              \
                        carg0->GetType()->GetString().c_str());                                                        \
        }                                                                                                              \
        break
2088 
/** Counts how many of the most-significant bits of the given value are
    zero.  The scan starts at the top bit of T and stops at the first set
    bit, so a value of zero yields the full bit width of T.
 */
template <typename T> static int countLeadingZeros(T val) {
    const size_t bitWidth = sizeof(T) * CHAR_BIT;
    // Probe mask that starts out on the most significant bit of T.
    T probe = (T)(T(1) << (bitWidth - 1));

    int zeros = 0;
    for (size_t remaining = bitWidth; remaining != 0; --remaining) {
        if (probe & val)
            return zeros;
        probe = probe >> 1;
        ++zeros;
    }
    return zeros;
}
2104 
2105 /** Constant fold the binary integer operations that aren't also applicable
2106     to floating-point types.
2107 */
2108 template <typename T, typename TRef>
lConstFoldBinaryIntOp(BinaryExpr::Op op,const T * v0,const T * v1,ConstExpr * carg0,SourcePos pos)2109 static ConstExpr *lConstFoldBinaryIntOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *carg0, SourcePos pos) {
2110     T result[ISPC_MAX_NVEC];
2111     int count = carg0->Count();
2112 
2113     switch (op) {
2114         FOLD_OP_REF(BinaryExpr::Shr, >>, TRef);
2115         FOLD_OP_REF(BinaryExpr::BitAnd, &, TRef);
2116         FOLD_OP_REF(BinaryExpr::BitXor, ^, TRef);
2117         FOLD_OP_REF(BinaryExpr::BitOr, |, TRef);
2118 
2119     case BinaryExpr::Shl:
2120         for (int i = 0; i < count; ++i) {
2121             result[i] = (T(v0[i]) << v1[i]);
2122             if (v1[i] > countLeadingZeros(v0[i])) {
2123                 Warning(pos, "Binary expression with type \"%s\" can't represent value.",
2124                         carg0->GetType()->GetString().c_str());
2125             }
2126         }
2127         break;
2128     case BinaryExpr::Mod:
2129         for (int i = 0; i < count; ++i) {
2130             if (v1[i] == 0) {
2131                 Warning(pos, "Remainder by zero is undefined.");
2132                 return NULL;
2133             } else {
2134                 result[i] = (v0[i] % v1[i]);
2135             }
2136         }
2137         break;
2138     default:
2139         return NULL;
2140     }
2141 
2142     return new ConstExpr(carg0->GetType(), result, carg0->pos);
2143 }
2144 
/** Constant fold the binary logical ops.

    Handles the comparison operators (<, >, <=, >=, ==, !=) and the
    logical && and || operators, applied element-wise to the
    carg0->Count() values in v0 and v1.

    Returns a new bool-typed ConstExpr (uniform if carg0's type is uniform,
    varying otherwise), or NULL if op is not one of the operators above.
 */
template <typename T>
static ConstExpr *lConstFoldBinaryLogicalOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *carg0) {
    // "result" and "count" are referenced by name from the FOLD_OP
    // expansions below.
    bool result[ISPC_MAX_NVEC];
    int count = carg0->Count();

    switch (op) {
        FOLD_OP(BinaryExpr::Lt, <);
        FOLD_OP(BinaryExpr::Gt, >);
        FOLD_OP(BinaryExpr::Le, <=);
        FOLD_OP(BinaryExpr::Ge, >=);
        FOLD_OP(BinaryExpr::Equal, ==);
        FOLD_OP(BinaryExpr::NotEqual, !=);
        FOLD_OP(BinaryExpr::LogicalAnd, &&);
        FOLD_OP(BinaryExpr::LogicalOr, ||);
    default:
        // Not an operator this helper knows how to fold.
        return NULL;
    }

    // The result is always a bool; only the variability is taken from the
    // first operand.
    const Type *rType = carg0->GetType()->IsUniformType() ? AtomicType::UniformBool : AtomicType::VaryingBool;
    return new ConstExpr(rType, result, carg0->pos);
}
2168 
2169 /** Constant fold binary arithmetic ops.
2170  */
2171 template <typename T, typename TRef>
lConstFoldBinaryArithOp(BinaryExpr::Op op,const T * v0,const T * v1,ConstExpr * carg0,SourcePos pos)2172 static ConstExpr *lConstFoldBinaryArithOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *carg0,
2173                                           SourcePos pos) {
2174     T result[ISPC_MAX_NVEC];
2175     int count = carg0->Count();
2176 
2177     switch (op) {
2178         FOLD_OP_REF(BinaryExpr::Add, +, TRef);
2179         FOLD_OP_REF(BinaryExpr::Sub, -, TRef);
2180         FOLD_OP_REF(BinaryExpr::Mul, *, TRef);
2181     case BinaryExpr::Div:
2182         for (int i = 0; i < count; ++i) {
2183             if (v1[i] == 0) {
2184                 Warning(pos, "Division by zero is undefined.");
2185                 return NULL;
2186             } else {
2187                 result[i] = (v0[i] / v1[i]);
2188             }
2189         }
2190         break;
2191     default:
2192         return NULL;
2193     }
2194 
2195     return new ConstExpr(carg0->GetType(), result, carg0->pos);
2196 }
2197 
/** Constant fold the various boolean binary ops.

    Handles the bitwise operators (&, ^, |), the comparison operators and
    the logical && and || operators, applied element-wise to the
    carg0->Count() bool values in v0 and v1.

    Returns a new ConstExpr with the same type as carg0, or NULL if op is
    not one of the operators above.
 */
static ConstExpr *lConstFoldBoolBinaryOp(BinaryExpr::Op op, const bool *v0, const bool *v1, ConstExpr *carg0) {
    // "result" and "count" are referenced by name from the FOLD_OP
    // expansions below.
    bool result[ISPC_MAX_NVEC];
    int count = carg0->Count();

    switch (op) {
        FOLD_OP(BinaryExpr::BitAnd, &);
        FOLD_OP(BinaryExpr::BitXor, ^);
        FOLD_OP(BinaryExpr::BitOr, |);
        FOLD_OP(BinaryExpr::Lt, <);
        FOLD_OP(BinaryExpr::Gt, >);
        FOLD_OP(BinaryExpr::Le, <=);
        FOLD_OP(BinaryExpr::Ge, >=);
        FOLD_OP(BinaryExpr::Equal, ==);
        FOLD_OP(BinaryExpr::NotEqual, !=);
        FOLD_OP(BinaryExpr::LogicalAnd, &&);
        FOLD_OP(BinaryExpr::LogicalOr, ||);
    default:
        // Not an operator this helper knows how to fold.
        return NULL;
    }

    return new ConstExpr(carg0->GetType(), result, carg0->pos);
}
2222 
2223 template <typename T>
lConstFoldBinaryFPOp(ConstExpr * constArg0,ConstExpr * constArg1,BinaryExpr::Op op,BinaryExpr * origExpr,SourcePos pos)2224 static Expr *lConstFoldBinaryFPOp(ConstExpr *constArg0, ConstExpr *constArg1, BinaryExpr::Op op, BinaryExpr *origExpr,
2225                                   SourcePos pos) {
2226     T v0[ISPC_MAX_NVEC], v1[ISPC_MAX_NVEC];
2227     constArg0->GetValues(v0);
2228     constArg1->GetValues(v1);
2229     ConstExpr *ret;
2230     if ((ret = lConstFoldBinaryArithOp<T, T>(op, v0, v1, constArg0, pos)) != NULL)
2231         return ret;
2232     else if ((ret = lConstFoldBinaryLogicalOp(op, v0, v1, constArg0)) != NULL)
2233         return ret;
2234     else
2235         return origExpr;
2236 }
2237 
2238 template <typename T, typename TRef>
lConstFoldBinaryIntOp(ConstExpr * constArg0,ConstExpr * constArg1,BinaryExpr::Op op,BinaryExpr * origExpr,SourcePos pos)2239 static Expr *lConstFoldBinaryIntOp(ConstExpr *constArg0, ConstExpr *constArg1, BinaryExpr::Op op, BinaryExpr *origExpr,
2240                                    SourcePos pos) {
2241     T v0[ISPC_MAX_NVEC], v1[ISPC_MAX_NVEC];
2242     constArg0->GetValues(v0);
2243     constArg1->GetValues(v1);
2244     ConstExpr *ret;
2245     if ((ret = lConstFoldBinaryArithOp<T, TRef>(op, v0, v1, constArg0, pos)) != NULL)
2246         return ret;
2247     else if ((ret = lConstFoldBinaryIntOp<T, TRef>(op, v0, v1, constArg0, pos)) != NULL)
2248         return ret;
2249     else if ((ret = lConstFoldBinaryLogicalOp(op, v0, v1, constArg0)) != NULL)
2250         return ret;
2251     else
2252         return origExpr;
2253 }
2254 
Optimize()2255 Expr *BinaryExpr::Optimize() {
2256     if (arg0 == NULL || arg1 == NULL)
2257         return NULL;
2258 
2259     ConstExpr *constArg0 = llvm::dyn_cast<ConstExpr>(arg0);
2260     ConstExpr *constArg1 = llvm::dyn_cast<ConstExpr>(arg1);
2261 
2262     if (g->opt.fastMath) {
2263         // optimizations related to division by floats..
2264 
2265         // transform x / const -> x * (1/const)
2266         if (op == Div && constArg1 != NULL) {
2267             const Type *type1 = constArg1->GetType();
2268             if (Type::EqualIgnoringConst(type1, AtomicType::UniformFloat) ||
2269                 Type::EqualIgnoringConst(type1, AtomicType::VaryingFloat)) {
2270                 float inv[ISPC_MAX_NVEC];
2271                 int count = constArg1->GetValues(inv);
2272                 for (int i = 0; i < count; ++i)
2273                     inv[i] = 1.f / inv[i];
2274                 Expr *einv = new ConstExpr(type1, inv, constArg1->pos);
2275                 Expr *e = new BinaryExpr(Mul, arg0, einv, pos);
2276                 e = ::TypeCheck(e);
2277                 if (e == NULL)
2278                     return NULL;
2279                 return ::Optimize(e);
2280             }
2281         }
2282 
2283         // transform x / y -> x * rcp(y)
2284         if (op == Div) {
2285             const Type *type1 = arg1->GetType();
2286             if (Type::EqualIgnoringConst(type1, AtomicType::UniformFloat) ||
2287                 Type::EqualIgnoringConst(type1, AtomicType::VaryingFloat)) {
2288                 // Get the symbol for the appropriate builtin
2289                 std::vector<Symbol *> rcpFuns;
2290                 m->symbolTable->LookupFunction("rcp", &rcpFuns);
2291                 if (rcpFuns.size() > 0) {
2292                     Expr *rcpSymExpr = new FunctionSymbolExpr("rcp", rcpFuns, pos);
2293                     ExprList *args = new ExprList(arg1, arg1->pos);
2294                     Expr *rcpCall = new FunctionCallExpr(rcpSymExpr, args, arg1->pos);
2295                     rcpCall = ::TypeCheck(rcpCall);
2296                     if (rcpCall == NULL)
2297                         return NULL;
2298                     rcpCall = ::Optimize(rcpCall);
2299                     if (rcpCall == NULL)
2300                         return NULL;
2301 
2302                     Expr *ret = new BinaryExpr(Mul, arg0, rcpCall, pos);
2303                     ret = ::TypeCheck(ret);
2304                     if (ret == NULL)
2305                         return NULL;
2306                     return ::Optimize(ret);
2307                 } else
2308                     Warning(pos, "rcp() not found from stdlib.  Can't apply "
2309                                  "fast-math rcp optimization.");
2310             }
2311         }
2312     }
2313 
2314     // From here on out, we're just doing constant folding, so if both args
2315     // aren't constants then we're done...
2316     if (constArg0 == NULL || constArg1 == NULL)
2317         return this;
2318 
2319     AssertPos(pos, Type::EqualIgnoringConst(arg0->GetType(), arg1->GetType()));
2320     const Type *type = arg0->GetType()->GetAsNonConstType();
2321     if (Type::Equal(type, AtomicType::UniformFloat) || Type::Equal(type, AtomicType::VaryingFloat)) {
2322         return lConstFoldBinaryFPOp<float>(constArg0, constArg1, op, this, pos);
2323     } else if (Type::Equal(type, AtomicType::UniformDouble) || Type::Equal(type, AtomicType::VaryingDouble)) {
2324         return lConstFoldBinaryFPOp<double>(constArg0, constArg1, op, this, pos);
2325     } else if (Type::Equal(type, AtomicType::UniformInt8) || Type::Equal(type, AtomicType::VaryingInt8)) {
2326         return lConstFoldBinaryIntOp<int8_t, int64_t>(constArg0, constArg1, op, this, pos);
2327     } else if (Type::Equal(type, AtomicType::UniformUInt8) || Type::Equal(type, AtomicType::VaryingUInt8)) {
2328         return lConstFoldBinaryIntOp<uint8_t, uint64_t>(constArg0, constArg1, op, this, pos);
2329     } else if (Type::Equal(type, AtomicType::UniformInt16) || Type::Equal(type, AtomicType::VaryingInt16)) {
2330         return lConstFoldBinaryIntOp<int16_t, int64_t>(constArg0, constArg1, op, this, pos);
2331     } else if (Type::Equal(type, AtomicType::UniformUInt16) || Type::Equal(type, AtomicType::VaryingUInt16)) {
2332         return lConstFoldBinaryIntOp<uint16_t, uint64_t>(constArg0, constArg1, op, this, pos);
2333     } else if (Type::Equal(type, AtomicType::UniformInt32) || Type::Equal(type, AtomicType::VaryingInt32)) {
2334         return lConstFoldBinaryIntOp<int32_t, int64_t>(constArg0, constArg1, op, this, pos);
2335     } else if (Type::Equal(type, AtomicType::UniformUInt32) || Type::Equal(type, AtomicType::VaryingUInt32)) {
2336         return lConstFoldBinaryIntOp<uint32_t, uint64_t>(constArg0, constArg1, op, this, pos);
2337     } else if (Type::Equal(type, AtomicType::UniformInt64) || Type::Equal(type, AtomicType::VaryingInt64)) {
2338         return lConstFoldBinaryIntOp<int64_t, int64_t>(constArg0, constArg1, op, this, pos);
2339     } else if (Type::Equal(type, AtomicType::UniformUInt64) || Type::Equal(type, AtomicType::VaryingUInt64)) {
2340         return lConstFoldBinaryIntOp<uint64_t, uint64_t>(constArg0, constArg1, op, this, pos);
2341     } else if (Type::Equal(type, AtomicType::UniformBool) || Type::Equal(type, AtomicType::VaryingBool)) {
2342         bool v0[ISPC_MAX_NVEC], v1[ISPC_MAX_NVEC];
2343         constArg0->GetValues(v0);
2344         constArg1->GetValues(v1);
2345         ConstExpr *ret;
2346         if ((ret = lConstFoldBoolBinaryOp(op, v0, v1, constArg0)) != NULL)
2347             return ret;
2348         else if ((ret = lConstFoldBinaryLogicalOp(op, v0, v1, constArg0)) != NULL)
2349             return ret;
2350         else
2351             return this;
2352     } else
2353         return this;
2354 }
2355 
TypeCheck()2356 Expr *BinaryExpr::TypeCheck() {
2357     if (arg0 == NULL || arg1 == NULL)
2358         return NULL;
2359 
2360     const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
2361     if (type0 == NULL || type1 == NULL)
2362         return NULL;
2363 
2364     // If either operand is a reference, dereference it before we move
2365     // forward
2366     if (CastType<ReferenceType>(type0) != NULL) {
2367         arg0 = new RefDerefExpr(arg0, arg0->pos);
2368         type0 = arg0->GetType();
2369         AssertPos(pos, type0 != NULL);
2370     }
2371     if (CastType<ReferenceType>(type1) != NULL) {
2372         arg1 = new RefDerefExpr(arg1, arg1->pos);
2373         type1 = arg1->GetType();
2374         AssertPos(pos, type1 != NULL);
2375     }
2376 
2377     // Convert arrays to pointers to their first elements
2378     if (CastType<ArrayType>(type0) != NULL) {
2379         arg0 = lArrayToPointer(arg0);
2380         type0 = arg0->GetType();
2381     }
2382     if (CastType<ArrayType>(type1) != NULL) {
2383         arg1 = lArrayToPointer(arg1);
2384         type1 = arg1->GetType();
2385     }
2386 
2387     // Prohibit binary operators with SOA types
2388     if (type0->GetSOAWidth() > 0) {
2389         Error(arg0->pos,
2390               "Illegal to use binary operator %s with SOA type "
2391               "\"%s\".",
2392               lOpString(op), type0->GetString().c_str());
2393         return NULL;
2394     }
2395     if (type1->GetSOAWidth() > 0) {
2396         Error(arg1->pos,
2397               "Illegal to use binary operator %s with SOA type "
2398               "\"%s\".",
2399               lOpString(op), type1->GetString().c_str());
2400         return NULL;
2401     }
2402 
2403     const PointerType *pt0 = CastType<PointerType>(type0);
2404     const PointerType *pt1 = CastType<PointerType>(type1);
2405     if (pt0 != NULL && pt1 != NULL && op == Sub) {
2406         // Pointer subtraction
2407         if (PointerType::IsVoidPointer(type0)) {
2408             Error(pos,
2409                   "Illegal to perform pointer arithmetic "
2410                   "on \"%s\" type.",
2411                   type0->GetString().c_str());
2412             return NULL;
2413         }
2414         if (PointerType::IsVoidPointer(type1)) {
2415             Error(pos,
2416                   "Illegal to perform pointer arithmetic "
2417                   "on \"%s\" type.",
2418                   type1->GetString().c_str());
2419             return NULL;
2420         }
2421         if (CastType<UndefinedStructType>(pt0->GetBaseType())) {
2422             Error(pos,
2423                   "Illegal to perform pointer arithmetic "
2424                   "on undefined struct type \"%s\".",
2425                   pt0->GetString().c_str());
2426             return NULL;
2427         }
2428         if (CastType<UndefinedStructType>(pt1->GetBaseType())) {
2429             Error(pos,
2430                   "Illegal to perform pointer arithmetic "
2431                   "on undefined struct type \"%s\".",
2432                   pt1->GetString().c_str());
2433             return NULL;
2434         }
2435 
2436         const Type *t = Type::MoreGeneralType(type0, type1, pos, "-");
2437         if (t == NULL)
2438             return NULL;
2439 
2440         arg0 = TypeConvertExpr(arg0, t, "pointer subtraction");
2441         arg1 = TypeConvertExpr(arg1, t, "pointer subtraction");
2442         if (arg0 == NULL || arg1 == NULL)
2443             return NULL;
2444 
2445         return this;
2446     } else if (((pt0 != NULL || pt1 != NULL) && op == Add) || (pt0 != NULL && op == Sub)) {
2447         // Handle ptr + int, int + ptr, ptr - int
2448         if (pt0 != NULL && pt1 != NULL) {
2449             Error(pos, "Illegal to add two pointer types \"%s\" and \"%s\".", pt0->GetString().c_str(),
2450                   pt1->GetString().c_str());
2451             return NULL;
2452         } else if (pt1 != NULL) {
2453             // put in canonical order with the pointer as the first operand
2454             // for GetValue()
2455             std::swap(arg0, arg1);
2456             std::swap(type0, type1);
2457             std::swap(pt0, pt1);
2458         }
2459 
2460         AssertPos(pos, pt0 != NULL);
2461 
2462         if (PointerType::IsVoidPointer(pt0)) {
2463             Error(pos,
2464                   "Illegal to perform pointer arithmetic "
2465                   "on \"%s\" type.",
2466                   pt0->GetString().c_str());
2467             return NULL;
2468         }
2469         if (CastType<UndefinedStructType>(pt0->GetBaseType())) {
2470             Error(pos,
2471                   "Illegal to perform pointer arithmetic "
2472                   "on undefined struct type \"%s\".",
2473                   pt0->GetString().c_str());
2474             return NULL;
2475         }
2476 
2477         const Type *offsetType = g->target->is32Bit() ? AtomicType::UniformInt32 : AtomicType::UniformInt64;
2478         if (pt0->IsVaryingType())
2479             offsetType = offsetType->GetAsVaryingType();
2480         if (type1->IsVaryingType()) {
2481             arg0 = TypeConvertExpr(arg0, type0->GetAsVaryingType(), "pointer addition");
2482             offsetType = offsetType->GetAsVaryingType();
2483             AssertPos(pos, arg0 != NULL);
2484         }
2485 
2486         arg1 = TypeConvertExpr(arg1, offsetType, lOpString(op));
2487         if (arg1 == NULL)
2488             return NULL;
2489 
2490         return this;
2491     }
2492 
2493     switch (op) {
2494     case Shl:
2495     case Shr:
2496     case BitAnd:
2497     case BitXor:
2498     case BitOr: {
2499         // Must have integer or bool-typed operands for these bit-related
2500         // ops; don't do any implicit conversions from floats here...
2501         if (!type0->IsIntType() && !type0->IsBoolType()) {
2502             Error(arg0->pos,
2503                   "First operand to binary operator \"%s\" must be "
2504                   "an integer or bool.",
2505                   lOpString(op));
2506             return NULL;
2507         }
2508         if (!type1->IsIntType() && !type1->IsBoolType()) {
2509             Error(arg1->pos,
2510                   "Second operand to binary operator \"%s\" must be "
2511                   "an integer or bool.",
2512                   lOpString(op));
2513             return NULL;
2514         }
2515 
2516         if (op == Shl || op == Shr) {
2517             bool isVarying = (type0->IsVaryingType() || type1->IsVaryingType());
2518             if (isVarying) {
2519                 arg0 = TypeConvertExpr(arg0, type0->GetAsVaryingType(), "shift operator");
2520                 if (arg0 == NULL)
2521                     return NULL;
2522                 type0 = arg0->GetType();
2523             }
2524             arg1 = TypeConvertExpr(arg1, type0, "shift operator");
2525             if (arg1 == NULL)
2526                 return NULL;
2527         } else {
2528             const Type *promotedType = Type::MoreGeneralType(type0, type1, arg0->pos, "binary bit op");
2529             if (promotedType == NULL)
2530                 return NULL;
2531 
2532             arg0 = TypeConvertExpr(arg0, promotedType, "binary bit op");
2533             arg1 = TypeConvertExpr(arg1, promotedType, "binary bit op");
2534             if (arg0 == NULL || arg1 == NULL)
2535                 return NULL;
2536         }
2537         return this;
2538     }
2539     case Add:
2540     case Sub:
2541     case Mul:
2542     case Div:
2543     case Mod: {
2544         // Must be numeric type for these.  (And mod is special--can't be float)
2545         if (!type0->IsNumericType() || (op == Mod && type0->IsFloatType())) {
2546             Error(arg0->pos,
2547                   "First operand to binary operator \"%s\" is of "
2548                   "invalid type \"%s\".",
2549                   lOpString(op), type0->GetString().c_str());
2550             return NULL;
2551         }
2552         if (!type1->IsNumericType() || (op == Mod && type1->IsFloatType())) {
2553             Error(arg1->pos,
2554                   "First operand to binary operator \"%s\" is of "
2555                   "invalid type \"%s\".",
2556                   lOpString(op), type1->GetString().c_str());
2557             return NULL;
2558         }
2559 
2560         const Type *promotedType = Type::MoreGeneralType(type0, type1, Union(arg0->pos, arg1->pos), lOpString(op));
2561         if (promotedType == NULL)
2562             return NULL;
2563 
2564         arg0 = TypeConvertExpr(arg0, promotedType, lOpString(op));
2565         arg1 = TypeConvertExpr(arg1, promotedType, lOpString(op));
2566         if (arg0 == NULL || arg1 == NULL)
2567             return NULL;
2568         return this;
2569     }
2570     case Lt:
2571     case Gt:
2572     case Le:
2573     case Ge:
2574     case Equal:
2575     case NotEqual: {
2576         const PointerType *pt0 = CastType<PointerType>(type0);
2577         const PointerType *pt1 = CastType<PointerType>(type1);
2578 
2579         // Convert '0' in expressions where the other expression is a
2580         // pointer type to a NULL pointer.
2581         if (pt0 != NULL && lIsAllIntZeros(arg1)) {
2582             arg1 = new NullPointerExpr(pos);
2583             type1 = arg1->GetType();
2584             pt1 = CastType<PointerType>(type1);
2585         } else if (pt1 != NULL && lIsAllIntZeros(arg0)) {
2586             arg0 = new NullPointerExpr(pos);
2587             type0 = arg1->GetType();
2588             pt0 = CastType<PointerType>(type0);
2589         }
2590 
2591         if (pt0 == NULL && pt1 == NULL) {
2592             if (!type0->IsBoolType() && !type0->IsNumericType()) {
2593                 Error(arg0->pos,
2594                       "First operand to operator \"%s\" is of "
2595                       "non-comparable type \"%s\".",
2596                       lOpString(op), type0->GetString().c_str());
2597                 return NULL;
2598             }
2599             if (!type1->IsBoolType() && !type1->IsNumericType()) {
2600                 Error(arg1->pos,
2601                       "Second operand to operator \"%s\" is of "
2602                       "non-comparable type \"%s\".",
2603                       lOpString(op), type1->GetString().c_str());
2604                 return NULL;
2605             }
2606         }
2607 
2608         const Type *promotedType = Type::MoreGeneralType(type0, type1, arg0->pos, lOpString(op));
2609         if (promotedType == NULL)
2610             return NULL;
2611 
2612         arg0 = TypeConvertExpr(arg0, promotedType, lOpString(op));
2613         arg1 = TypeConvertExpr(arg1, promotedType, lOpString(op));
2614         if (arg0 == NULL || arg1 == NULL)
2615             return NULL;
2616         return this;
2617     }
2618     case LogicalAnd:
2619     case LogicalOr: {
2620         // For now, we just type convert to boolean types, of the same
2621         // variability as the original types.  (When generating code, it's
2622         // useful to have preserved the uniform/varying distinction.)
2623         const AtomicType *boolType0 = type0->IsUniformType() ? AtomicType::UniformBool : AtomicType::VaryingBool;
2624         const AtomicType *boolType1 = type1->IsUniformType() ? AtomicType::UniformBool : AtomicType::VaryingBool;
2625 
2626         const Type *destType0 = NULL, *destType1 = NULL;
2627         const VectorType *vtype0 = CastType<VectorType>(type0);
2628         const VectorType *vtype1 = CastType<VectorType>(type1);
2629         if (vtype0 && vtype1) {
2630             int sz0 = vtype0->GetElementCount(), sz1 = vtype1->GetElementCount();
2631             if (sz0 != sz1) {
2632                 Error(pos,
2633                       "Can't do logical operation \"%s\" between vector types of "
2634                       "different sizes (%d vs. %d).",
2635                       lOpString(op), sz0, sz1);
2636                 return NULL;
2637             }
2638             destType0 = new VectorType(boolType0, sz0);
2639             destType1 = new VectorType(boolType1, sz1);
2640         } else if (vtype0 != NULL) {
2641             destType0 = new VectorType(boolType0, vtype0->GetElementCount());
2642             destType1 = new VectorType(boolType1, vtype0->GetElementCount());
2643         } else if (vtype1 != NULL) {
2644             destType0 = new VectorType(boolType0, vtype1->GetElementCount());
2645             destType1 = new VectorType(boolType1, vtype1->GetElementCount());
2646         } else {
2647             destType0 = boolType0;
2648             destType1 = boolType1;
2649         }
2650 
2651         arg0 = TypeConvertExpr(arg0, destType0, lOpString(op));
2652         arg1 = TypeConvertExpr(arg1, destType1, lOpString(op));
2653         if (arg0 == NULL || arg1 == NULL)
2654             return NULL;
2655         return this;
2656     }
2657     case Comma:
2658         return this;
2659     default:
2660         FATAL("logic error");
2661         return NULL;
2662     }
2663 }
2664 
GetLValueType() const2665 const Type *BinaryExpr::GetLValueType() const {
2666     const Type *t = GetType();
2667     if (CastType<PointerType>(t) != NULL) {
2668         // Are we doing something like (basePtr + offset)[...] = ...
2669         return t;
2670     } else {
2671         return NULL;
2672     }
2673 }
2674 
EstimateCost() const2675 int BinaryExpr::EstimateCost() const {
2676     if (llvm::dyn_cast<ConstExpr>(arg0) != NULL && llvm::dyn_cast<ConstExpr>(arg1) != NULL)
2677         return 0;
2678 
2679     return (op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP : COST_SIMPLE_ARITH_LOGIC_OP;
2680 }
2681 
Print() const2682 void BinaryExpr::Print() const {
2683     if (!arg0 || !arg1 || !GetType())
2684         return;
2685 
2686     printf("[ %s ] (", GetType()->GetString().c_str());
2687     arg0->Print();
2688     printf(" %s ", lOpString(op));
2689     arg1->Print();
2690     printf(")");
2691     pos.Print();
2692 }
2693 
lGetBinaryExprStorageConstant(const Type * type,const BinaryExpr * bExpr,bool isStorageType)2694 static std::pair<llvm::Constant *, bool> lGetBinaryExprStorageConstant(const Type *type, const BinaryExpr *bExpr,
2695                                                                        bool isStorageType) {
2696 
2697     const BinaryExpr::Op op = bExpr->op;
2698     Expr *arg0 = bExpr->arg0;
2699     Expr *arg1 = bExpr->arg1;
2700 
2701     // Are we doing something like (basePtr + offset)[...] = ... for a Global
2702     // Variable
2703     if (!bExpr->GetLValueType())
2704         return std::pair<llvm::Constant *, bool>(NULL, false);
2705 
2706     // We are limiting cases to just addition and subtraction involving
2707     // pointer addresses
2708     // Case 1 : first argument is a pointer address.
2709     // In this case as long as the second argument is a constant value, we are fine
2710     // Case 2 : second argument is a pointer address.
2711     // In this case, it has to be an addition with first argument as
2712     // a constant value.
2713     if (!((op == BinaryExpr::Op::Add) || (op == BinaryExpr::Op::Sub)))
2714         return std::pair<llvm::Constant *, bool>(NULL, false);
2715     if (op == BinaryExpr::Op::Sub) {
2716         // Ignore cases where subtrahend is a PointerType
2717         // Eg. b - 5 is valid but 5 - b is not.
2718         if (CastType<PointerType>(arg1->GetType()))
2719             return std::pair<llvm::Constant *, bool>(NULL, false);
2720     }
2721 
2722     // 'isNotValidForMultiTargetGlobal' is required to let the caller know
2723     // that the llvm::constant value returned cannot be used in case of
2724     // multi-target compilation for initialization of globals. This is due
2725     // to different constant values for different targets, i.e. computation
2726     // involving sizeof() of varying types.
2727     // Since converting expr to constant can be a recursive process, we need
2728     // to ensure that if the flag is set by any expr in the chain, it's
2729     // reflected in the final return value.
2730     bool isNotValidForMultiTargetGlobal = false;
2731     if (const PointerType *pt0 = CastType<PointerType>(arg0->GetType())) {
2732         std::pair<llvm::Constant *, bool> c1Pair;
2733         if (isStorageType)
2734             c1Pair = arg0->GetStorageConstant(pt0);
2735         else
2736             c1Pair = arg0->GetConstant(pt0);
2737         llvm::Constant *c1 = c1Pair.first;
2738         isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c1Pair.second;
2739         ConstExpr *cExpr = llvm::dyn_cast<ConstExpr>(arg1);
2740         if ((cExpr == NULL) || (c1 == NULL))
2741             return std::pair<llvm::Constant *, bool>(NULL, false);
2742         std::pair<llvm::Constant *, bool> c2Pair;
2743         if (isStorageType)
2744             c2Pair = cExpr->GetStorageConstant(cExpr->GetType());
2745         else
2746             c2Pair = cExpr->GetConstant(cExpr->GetType());
2747         llvm::Constant *c2 = c2Pair.first;
2748         isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c2Pair.second;
2749         if (op == BinaryExpr::Op::Sub)
2750             c2 = llvm::ConstantExpr::getNeg(c2);
2751         llvm::Constant *c = llvm::ConstantExpr::getGetElementPtr(PTYPE(c1), c1, c2);
2752         return std::pair<llvm::Constant *, bool>(c, isNotValidForMultiTargetGlobal);
2753     } else if (const PointerType *pt1 = CastType<PointerType>(arg1->GetType())) {
2754         std::pair<llvm::Constant *, bool> c1Pair;
2755         if (isStorageType)
2756             c1Pair = arg1->GetStorageConstant(pt1);
2757         else
2758             c1Pair = arg1->GetConstant(pt1);
2759         llvm::Constant *c1 = c1Pair.first;
2760         isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c1Pair.second;
2761         ConstExpr *cExpr = llvm::dyn_cast<ConstExpr>(arg0);
2762         if ((cExpr == NULL) || (c1 == NULL))
2763             return std::pair<llvm::Constant *, bool>(NULL, false);
2764         std::pair<llvm::Constant *, bool> c2Pair;
2765         if (isStorageType)
2766             c2Pair = cExpr->GetStorageConstant(cExpr->GetType());
2767         else
2768             c2Pair = cExpr->GetConstant(cExpr->GetType());
2769         llvm::Constant *c2 = c2Pair.first;
2770         isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || c2Pair.second;
2771         llvm::Constant *c = llvm::ConstantExpr::getGetElementPtr(PTYPE(c1), c1, c2);
2772         return std::pair<llvm::Constant *, bool>(c, isNotValidForMultiTargetGlobal);
2773     }
2774 
2775     return std::pair<llvm::Constant *, bool>(NULL, false);
2776 }
2777 
GetStorageConstant(const Type * type) const2778 std::pair<llvm::Constant *, bool> BinaryExpr::GetStorageConstant(const Type *type) const {
2779     return lGetBinaryExprStorageConstant(type, this, true);
2780 }
2781 
GetConstant(const Type * type) const2782 std::pair<llvm::Constant *, bool> BinaryExpr::GetConstant(const Type *type) const {
2783 
2784     return lGetBinaryExprStorageConstant(type, this, false);
2785 }
2786 ///////////////////////////////////////////////////////////////////////////
2787 // AssignExpr
2788 
lOpString(AssignExpr::Op op)2789 static const char *lOpString(AssignExpr::Op op) {
2790     switch (op) {
2791     case AssignExpr::Assign:
2792         return "assignment operator";
2793     case AssignExpr::MulAssign:
2794         return "*=";
2795     case AssignExpr::DivAssign:
2796         return "/=";
2797     case AssignExpr::ModAssign:
2798         return "%%=";
2799     case AssignExpr::AddAssign:
2800         return "+=";
2801     case AssignExpr::SubAssign:
2802         return "-=";
2803     case AssignExpr::ShlAssign:
2804         return "<<=";
2805     case AssignExpr::ShrAssign:
2806         return ">>=";
2807     case AssignExpr::AndAssign:
2808         return "&=";
2809     case AssignExpr::XorAssign:
2810         return "^=";
2811     case AssignExpr::OrAssign:
2812         return "|=";
2813     default:
2814         FATAL("Missing op in lOpstring");
2815         return "";
2816     }
2817 }
2818 
2819 /** Emit code to do an "assignment + operation" operator, e.g. "+=".
2820  */
lEmitOpAssign(AssignExpr::Op op,Expr * arg0,Expr * arg1,const Type * type,Symbol * baseSym,SourcePos pos,FunctionEmitContext * ctx)2821 static llvm::Value *lEmitOpAssign(AssignExpr::Op op, Expr *arg0, Expr *arg1, const Type *type, Symbol *baseSym,
2822                                   SourcePos pos, FunctionEmitContext *ctx) {
2823     llvm::Value *lv = arg0->GetLValue(ctx);
2824     if (!lv) {
2825         // FIXME: I think this test is unnecessary and that this case
2826         // should be caught during typechecking
2827         Error(pos, "Can't assign to left-hand side of expression.");
2828         return NULL;
2829     }
2830     const Type *lvalueType = arg0->GetLValueType();
2831     const Type *resultType = arg0->GetType();
2832     if (lvalueType == NULL || resultType == NULL)
2833         return NULL;
2834 
2835     // Get the value on the right-hand side of the assignment+operation
2836     // operator and load the current value on the left-hand side.
2837     llvm::Value *rvalue = arg1->GetValue(ctx);
2838     llvm::Value *mask = lMaskForSymbol(baseSym, ctx);
2839     ctx->SetDebugPos(arg0->pos);
2840     llvm::Value *oldLHS = ctx->LoadInst(lv, mask, lvalueType);
2841     ctx->SetDebugPos(pos);
2842 
2843     // Map the operator to the corresponding BinaryExpr::Op operator
2844     BinaryExpr::Op basicop;
2845     switch (op) {
2846     case AssignExpr::MulAssign:
2847         basicop = BinaryExpr::Mul;
2848         break;
2849     case AssignExpr::DivAssign:
2850         basicop = BinaryExpr::Div;
2851         break;
2852     case AssignExpr::ModAssign:
2853         basicop = BinaryExpr::Mod;
2854         break;
2855     case AssignExpr::AddAssign:
2856         basicop = BinaryExpr::Add;
2857         break;
2858     case AssignExpr::SubAssign:
2859         basicop = BinaryExpr::Sub;
2860         break;
2861     case AssignExpr::ShlAssign:
2862         basicop = BinaryExpr::Shl;
2863         break;
2864     case AssignExpr::ShrAssign:
2865         basicop = BinaryExpr::Shr;
2866         break;
2867     case AssignExpr::AndAssign:
2868         basicop = BinaryExpr::BitAnd;
2869         break;
2870     case AssignExpr::XorAssign:
2871         basicop = BinaryExpr::BitXor;
2872         break;
2873     case AssignExpr::OrAssign:
2874         basicop = BinaryExpr::BitOr;
2875         break;
2876     default:
2877         FATAL("logic error in lEmitOpAssign()");
2878         return NULL;
2879     }
2880 
2881     // Emit the code to compute the new value
2882     llvm::Value *newValue = NULL;
2883     switch (op) {
2884     case AssignExpr::MulAssign:
2885     case AssignExpr::DivAssign:
2886     case AssignExpr::ModAssign:
2887     case AssignExpr::AddAssign:
2888     case AssignExpr::SubAssign:
2889         newValue = lEmitBinaryArith(basicop, oldLHS, rvalue, type, arg1->GetType(), ctx, pos);
2890         break;
2891     case AssignExpr::ShlAssign:
2892     case AssignExpr::ShrAssign:
2893     case AssignExpr::AndAssign:
2894     case AssignExpr::XorAssign:
2895     case AssignExpr::OrAssign:
2896         newValue = lEmitBinaryBitOp(basicop, oldLHS, rvalue, arg0->GetType()->IsUnsignedType(), ctx);
2897         break;
2898     default:
2899         FATAL("logic error in lEmitOpAssign");
2900         return NULL;
2901     }
2902 
2903     // And store the result back to the lvalue.
2904     ctx->SetDebugPos(arg0->pos);
2905     lStoreAssignResult(newValue, lv, resultType, lvalueType, ctx, baseSym);
2906 
2907     return newValue;
2908 }
2909 
AssignExpr(AssignExpr::Op o,Expr * a,Expr * b,SourcePos p)2910 AssignExpr::AssignExpr(AssignExpr::Op o, Expr *a, Expr *b, SourcePos p) : Expr(p, AssignExprID), op(o) {
2911     lvalue = a;
2912     rvalue = b;
2913 }
2914 
GetValue(FunctionEmitContext * ctx) const2915 llvm::Value *AssignExpr::GetValue(FunctionEmitContext *ctx) const {
2916     const Type *type = NULL;
2917     if (lvalue == NULL || rvalue == NULL || (type = GetType()) == NULL)
2918         return NULL;
2919     ctx->SetDebugPos(pos);
2920 
2921     Symbol *baseSym = lvalue->GetBaseSymbol();
2922 
2923     switch (op) {
2924     case Assign: {
2925         llvm::Value *ptr = lvalue->GetLValue(ctx);
2926         if (ptr == NULL) {
2927             Error(lvalue->pos, "Left hand side of assignment expression can't "
2928                                "be assigned to.");
2929             return NULL;
2930         }
2931         const Type *ptrType = lvalue->GetLValueType();
2932         const Type *valueType = rvalue->GetType();
2933         if (ptrType == NULL || valueType == NULL) {
2934             AssertPos(pos, m->errorCount > 0);
2935             return NULL;
2936         }
2937 
2938         llvm::Value *value = rvalue->GetValue(ctx);
2939         if (value == NULL) {
2940             AssertPos(pos, m->errorCount > 0);
2941             return NULL;
2942         }
2943 
2944         ctx->SetDebugPos(lvalue->pos);
2945 
2946         lStoreAssignResult(value, ptr, valueType, ptrType, ctx, baseSym);
2947 
2948         return value;
2949     }
2950     case MulAssign:
2951     case DivAssign:
2952     case ModAssign:
2953     case AddAssign:
2954     case SubAssign:
2955     case ShlAssign:
2956     case ShrAssign:
2957     case AndAssign:
2958     case XorAssign:
2959     case OrAssign: {
2960         // This should be caught during type checking
2961         AssertPos(pos, !CastType<ArrayType>(type) && !CastType<StructType>(type));
2962         return lEmitOpAssign(op, lvalue, rvalue, type, baseSym, pos, ctx);
2963     }
2964     default:
2965         FATAL("logic error in AssignExpr::GetValue()");
2966         return NULL;
2967     }
2968 }
2969 
Optimize()2970 Expr *AssignExpr::Optimize() {
2971     if (lvalue == NULL || rvalue == NULL)
2972         return NULL;
2973     return this;
2974 }
2975 
GetType() const2976 const Type *AssignExpr::GetType() const { return lvalue ? lvalue->GetType() : NULL; }
2977 
2978 /** Recursively checks a structure type to see if it (or any struct type
2979     that it holds) has a const-qualified member. */
lCheckForConstStructMember(SourcePos pos,const StructType * structType,const StructType * initialType)2980 static bool lCheckForConstStructMember(SourcePos pos, const StructType *structType, const StructType *initialType) {
2981     for (int i = 0; i < structType->GetElementCount(); ++i) {
2982         const Type *t = structType->GetElementType(i);
2983         if (t->IsConstType()) {
2984             if (structType == initialType)
2985                 Error(pos,
2986                       "Illegal to assign to type \"%s\" due to element "
2987                       "\"%s\" with type \"%s\".",
2988                       structType->GetString().c_str(), structType->GetElementName(i).c_str(), t->GetString().c_str());
2989             else
2990                 Error(pos,
2991                       "Illegal to assign to type \"%s\" in type \"%s\" "
2992                       "due to element \"%s\" with type \"%s\".",
2993                       structType->GetString().c_str(), initialType->GetString().c_str(),
2994                       structType->GetElementName(i).c_str(), t->GetString().c_str());
2995             return true;
2996         }
2997 
2998         const StructType *st = CastType<StructType>(t);
2999         if (st != NULL && lCheckForConstStructMember(pos, st, initialType))
3000             return true;
3001     }
3002     return false;
3003 }
3004 
TypeCheck()3005 Expr *AssignExpr::TypeCheck() {
3006     if (lvalue == NULL || rvalue == NULL)
3007         return NULL;
3008 
3009     bool lvalueIsReference = CastType<ReferenceType>(lvalue->GetType()) != NULL;
3010     if (lvalueIsReference)
3011         lvalue = new RefDerefExpr(lvalue, lvalue->pos);
3012 
3013     if (PossiblyResolveFunctionOverloads(rvalue, lvalue->GetType()) == false) {
3014         Error(pos, "Unable to find overloaded function for function "
3015                    "pointer assignment.");
3016         return NULL;
3017     }
3018 
3019     const Type *lhsType = lvalue->GetType();
3020     if (lhsType == NULL) {
3021         AssertPos(pos, m->errorCount > 0);
3022         return NULL;
3023     }
3024 
3025     if (lhsType->IsConstType()) {
3026         Error(lvalue->pos,
3027               "Can't assign to type \"%s\" on left-hand side of "
3028               "expression.",
3029               lhsType->GetString().c_str());
3030         return NULL;
3031     }
3032 
3033     if (CastType<PointerType>(lhsType) != NULL) {
3034         if (op == AddAssign || op == SubAssign) {
3035             if (PointerType::IsVoidPointer(lhsType)) {
3036                 Error(pos,
3037                       "Illegal to perform pointer arithmetic on \"%s\" "
3038                       "type.",
3039                       lhsType->GetString().c_str());
3040                 return NULL;
3041             }
3042 
3043             const Type *deltaType = g->target->is32Bit() ? AtomicType::UniformInt32 : AtomicType::UniformInt64;
3044             if (lhsType->IsVaryingType())
3045                 deltaType = deltaType->GetAsVaryingType();
3046             rvalue = TypeConvertExpr(rvalue, deltaType, lOpString(op));
3047         } else if (op == Assign)
3048             rvalue = TypeConvertExpr(rvalue, lhsType, "assignment");
3049         else {
3050             Error(lvalue->pos, "Assignment operator \"%s\" is illegal with pointer types.", lOpString(op));
3051             return NULL;
3052         }
3053     } else if (CastType<ArrayType>(lhsType) != NULL) {
3054         Error(lvalue->pos, "Illegal to assign to array type \"%s\".", lhsType->GetString().c_str());
3055         return NULL;
3056     } else
3057         rvalue = TypeConvertExpr(rvalue, lhsType, lOpString(op));
3058 
3059     if (rvalue == NULL)
3060         return NULL;
3061 
3062     if (lhsType->IsFloatType() == true &&
3063         (op == ShlAssign || op == ShrAssign || op == AndAssign || op == XorAssign || op == OrAssign)) {
3064         Error(pos,
3065               "Illegal to use %s operator with floating-point "
3066               "operands.",
3067               lOpString(op));
3068         return NULL;
3069     }
3070 
3071     const StructType *st = CastType<StructType>(lhsType);
3072     if (st != NULL) {
3073         // Make sure we're not assigning to a struct that has a constant member
3074         if (lCheckForConstStructMember(pos, st, st))
3075             return NULL;
3076 
3077         if (op != Assign) {
3078             Error(lvalue->pos,
3079                   "Assignment operator \"%s\" is illegal with struct "
3080                   "type \"%s\".",
3081                   lOpString(op), st->GetString().c_str());
3082             return NULL;
3083         }
3084     }
3085     return this;
3086 }
3087 
EstimateCost() const3088 int AssignExpr::EstimateCost() const {
3089     if (op == Assign)
3090         return COST_ASSIGN;
3091     if (op == DivAssign || op == ModAssign)
3092         return COST_ASSIGN + COST_COMPLEX_ARITH_OP;
3093     else
3094         return COST_ASSIGN + COST_SIMPLE_ARITH_LOGIC_OP;
3095 }
3096 
Print() const3097 void AssignExpr::Print() const {
3098     if (!lvalue || !rvalue || !GetType())
3099         return;
3100 
3101     printf("[%s] assign (", GetType()->GetString().c_str());
3102     lvalue->Print();
3103     printf(" %s ", lOpString(op));
3104     rvalue->Print();
3105     printf(")");
3106     pos.Print();
3107 }
3108 
3109 ///////////////////////////////////////////////////////////////////////////
3110 // SelectExpr
3111 
SelectExpr(Expr * t,Expr * e1,Expr * e2,SourcePos p)3112 SelectExpr::SelectExpr(Expr *t, Expr *e1, Expr *e2, SourcePos p) : Expr(p, SelectExprID) {
3113     test = t;
3114     expr1 = e1;
3115     expr2 = e2;
3116 }
3117 
/** Emit code to select between two varying values based on a varying test
    value.

    The selection is done through a stack temporary: expr2 is stored to it
    first, then expr1 is stored over it using 'test' as the store mask, and
    the blended result is loaded back.

    @param ctx    emit context used to generate the instructions
    @param test   value passed as the mask of the second store
    @param expr1  value kept where the test is on
    @param expr2  value kept where the test is off
    @param type   type of both values and of the result
    @return the loaded, blended value
 */
static llvm::Value *lEmitVaryingSelect(FunctionEmitContext *ctx, llvm::Value *test, llvm::Value *expr1,
                                       llvm::Value *expr2, const Type *type) {

    // Temporary that receives the blended result
    llvm::Value *resultPtr = ctx->AllocaInst(type, "selectexpr_tmp");
    Assert(resultPtr != NULL);
    // Don't need to worry about masking here
    ctx->StoreInst(expr2, resultPtr, type, type->IsUniformType());
    // Use masking to conditionally store the expr1 values
    Assert(resultPtr->getType() == PointerType::GetUniform(type)->LLVMStorageType(g->ctx));
    ctx->StoreInst(expr1, resultPtr, test, type, PointerType::GetUniform(type));
    return ctx->LoadInst(resultPtr, type, "selectexpr_final");
}
3133 
/** Emit code that evaluates one arm of a varying select only if the test
    is on for at least one running program instance.

    A conditional branch skips the evaluation entirely when
    (testVal & fullMask) has no lanes on.  Otherwise the internal mask is set
    to (testVal & oldMask), 'expr' is evaluated, and its value is stored to
    'exprPtr'.  On return the current basic block is the "select_done" block.
    The internal mask is not restored here (the caller in this file resets it
    afterwards).

    @param ctx       emit context used to generate the instructions
    @param testVal   per-lane condition under which 'expr' is wanted
    @param oldMask   internal mask in effect before the select
    @param fullMask  full execution mask in effect before the select
    @param expr      expression to evaluate conditionally
    @param exprPtr   storage that receives the value of 'expr'
 */
static void lEmitSelectExprCode(FunctionEmitContext *ctx, llvm::Value *testVal, llvm::Value *oldMask,
                                llvm::Value *fullMask, Expr *expr, llvm::Value *exprPtr) {
    llvm::BasicBlock *bbEval = ctx->CreateBasicBlock("select_eval_expr", ctx->GetCurrentBasicBlock());
    llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("select_done", bbEval);

    // Check to see if the test was true for any of the currently executing
    // program instances.
    llvm::Value *testAndFullMask = ctx->BinaryOperator(llvm::Instruction::And, testVal, fullMask, "test&mask");
    llvm::Value *anyOn = ctx->Any(testAndFullMask);
    ctx->BranchInst(bbEval, bbDone, anyOn);

    // Evaluate the expression with the internal mask narrowed to the lanes
    // where the test holds, and save its value.
    ctx->SetCurrentBasicBlock(bbEval);
    llvm::Value *testAndMask = ctx->BinaryOperator(llvm::Instruction::And, testVal, oldMask, "test&mask");
    ctx->SetInternalMask(testAndMask);
    llvm::Value *exprVal = expr->GetValue(ctx);
    ctx->StoreInst(exprVal, exprPtr, expr->GetType(), expr->GetType()->IsUniformType());
    ctx->BranchInst(bbDone);

    // Both paths continue here.
    ctx->SetCurrentBasicBlock(bbDone);
}
3154 
HasAmbiguousVariability(std::vector<const Expr * > & warn) const3155 bool SelectExpr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
3156     bool isExpr1Amb = false;
3157     bool isExpr2Amb = false;
3158     if (expr1 != NULL) {
3159         const Type *type1 = expr1->GetType();
3160         if (expr1->HasAmbiguousVariability(warn)) {
3161             isExpr1Amb = true;
3162         } else if ((type1 != NULL) && (type1->IsVaryingType())) {
3163             // If either expr is varying, then the expression is un-ambiguously varying.
3164             return false;
3165         }
3166     }
3167     if (expr2 != NULL) {
3168         const Type *type2 = expr2->GetType();
3169         if (expr2->HasAmbiguousVariability(warn)) {
3170             isExpr2Amb = true;
3171         } else if ((type2 != NULL) && (type2->IsVaryingType())) {
3172             // If either arg is varying, then the expression is un-ambiguously varying.
3173             return false;
3174         }
3175     }
3176     if (isExpr1Amb || isExpr2Amb) {
3177         return true;
3178     }
3179 
3180     return false;
3181 }
3182 
/** Emits code for the select expression and returns its value.

    Three cases are handled, chosen by the type of the test:
    - uniform bool test: a real branch, so that only the selected
      alternative is evaluated; the result is a PHI of the two values.
    - varying (non-vector) bool test: each alternative is evaluated under a
      narrowed internal mask (optionally skipped entirely when no lane needs
      it), and the two results are blended with lEmitVaryingSelect().
    - vector-typed test: both alternatives are always evaluated and an
      element-wise select is performed.

    Returns NULL if the test or either alternative is missing.
 */
llvm::Value *SelectExpr::GetValue(FunctionEmitContext *ctx) const {
    if (!expr1 || !expr2 || !test)
        return NULL;

    ctx->SetDebugPos(pos);

    // NOTE(review): test->GetType() is dereferenced without a NULL check;
    // presumably type checking guarantees it is non-NULL here -- confirm.
    const Type *testType = test->GetType()->GetAsNonConstType();
    // This should be taken care of during typechecking
    AssertPos(pos, Type::Equal(testType->GetBaseType(), AtomicType::UniformBool) ||
                       Type::Equal(testType->GetBaseType(), AtomicType::VaryingBool));

    const Type *type = expr1->GetType();

    if (Type::Equal(testType, AtomicType::UniformBool)) {
        // Simple case of a single uniform bool test expression; we just
        // want one of the two expressions.  In this case, we can be
        // careful to evaluate just the one of the expressions that we need
        // the value of so that if the other one has side-effects or
        // accesses invalid memory, it doesn't execute.
        llvm::Value *testVal = test->GetValue(ctx);
        llvm::BasicBlock *testTrue = ctx->CreateBasicBlock("select_true", ctx->GetCurrentBasicBlock());
        llvm::BasicBlock *testFalse = ctx->CreateBasicBlock("select_false", testTrue);
        llvm::BasicBlock *testDone = ctx->CreateBasicBlock("select_done", testFalse);
        ctx->BranchInst(testTrue, testFalse, testVal);

        ctx->SetCurrentBasicBlock(testTrue);
        llvm::Value *expr1Val = expr1->GetValue(ctx);
        // Note that truePred won't be necessarily equal to testTrue, in
        // case the expr1->GetValue() call changes the current basic block.
        llvm::BasicBlock *truePred = ctx->GetCurrentBasicBlock();
        ctx->BranchInst(testDone);

        ctx->SetCurrentBasicBlock(testFalse);
        llvm::Value *expr2Val = expr2->GetValue(ctx);
        // See comment above truePred for why we can't just assume we're in
        // the testFalse basic block here.
        llvm::BasicBlock *falsePred = ctx->GetCurrentBasicBlock();
        ctx->BranchInst(testDone);

        // Merge the two values with a PHI node in the join block.
        ctx->SetCurrentBasicBlock(testDone);
        llvm::PHINode *ret = ctx->PhiNode(expr1Val->getType(), 2, "select");
        ret->addIncoming(expr1Val, truePred);
        ret->addIncoming(expr2Val, falsePred);
        return ret;
    } else if (CastType<VectorType>(testType) == NULL) {
        // the test is a varying bool type
        llvm::Value *testVal = test->GetValue(ctx);
        AssertPos(pos, testVal->getType() == LLVMTypes::MaskType);
        llvm::Value *oldMask = ctx->GetInternalMask();
        llvm::Value *fullMask = ctx->GetFullMask();

        // We don't want to incur the overhead for short-circuit evaluation
        // for expressions that are both computationally simple and safe to
        // run with an "all off" mask.
        int threshold = g->target->isGenXTarget() ? PREDICATE_SAFE_SHORT_CIRC_GENX_STATEMENT_COST
                                                  : PREDICATE_SAFE_IF_STATEMENT_COST;
        bool shortCircuit1 = (::EstimateCost(expr1) > threshold || SafeToRunWithMaskAllOff(expr1) == false);
        bool shortCircuit2 = (::EstimateCost(expr2) > threshold || SafeToRunWithMaskAllOff(expr2) == false);

        Debug(expr1->pos, "%sshort circuiting evaluation for select expr", shortCircuit1 ? "" : "Not ");
        Debug(expr2->pos, "%sshort circuiting evaluation for select expr", shortCircuit2 ? "" : "Not ");

        // Temporary storage to store the values computed for each
        // expression, if any.  (These stay as uninitialized memory if we
        // short circuit around the corresponding expression.)
        // NOTE(review): expr2Ptr is allocated with expr1's type; this
        // presumably relies on type checking having converted both
        // alternatives to the same type -- confirm.
        llvm::Value *expr1Ptr = ctx->AllocaInst(expr1->GetType());
        llvm::Value *expr2Ptr = ctx->AllocaInst(expr1->GetType());

        // First alternative: evaluated for the lanes where the test is on.
        if (shortCircuit1)
            lEmitSelectExprCode(ctx, testVal, oldMask, fullMask, expr1, expr1Ptr);
        else {
            ctx->SetInternalMaskAnd(oldMask, testVal);
            llvm::Value *expr1Val = expr1->GetValue(ctx);
            ctx->StoreInst(expr1Val, expr1Ptr, expr1->GetType(), expr1->GetType()->IsUniformType());
        }

        // Second alternative: evaluated for the lanes where the test is off.
        if (shortCircuit2) {
            llvm::Value *notTest = ctx->NotOperator(testVal);
            lEmitSelectExprCode(ctx, notTest, oldMask, fullMask, expr2, expr2Ptr);
        } else {
            ctx->SetInternalMaskAndNot(oldMask, testVal);
            llvm::Value *expr2Val = expr2->GetValue(ctx);
            ctx->StoreInst(expr2Val, expr2Ptr, expr2->GetType(), expr2->GetType()->IsUniformType());
        }

        // Restore the mask that was active on entry, then blend the two
        // stored results.
        ctx->SetInternalMask(oldMask);
        llvm::Value *expr1Val = ctx->LoadInst(expr1Ptr, expr1->GetType());
        llvm::Value *expr2Val = ctx->LoadInst(expr2Ptr, expr2->GetType());
        return lEmitVaryingSelect(ctx, testVal, expr1Val, expr2Val, type);
    } else {
        // FIXME? Short-circuiting doesn't work in the case of
        // vector-valued test expressions.  (We could also just prohibit
        // these and place the issue in the user's hands...)
        llvm::Value *testVal = test->GetValue(ctx);
        llvm::Value *expr1Val = expr1->GetValue(ctx);
        llvm::Value *expr2Val = expr2->GetValue(ctx);

        ctx->SetDebugPos(pos);
        const VectorType *vt = CastType<VectorType>(type);
        // Things that typechecking should have caught
        AssertPos(pos, vt != NULL);
        AssertPos(pos, CastType<VectorType>(testType) != NULL &&
                           (CastType<VectorType>(testType)->GetElementCount() == vt->GetElementCount()));

        // Do an element-wise select
        llvm::Value *result = llvm::UndefValue::get(type->LLVMType(g->ctx));
        for (int i = 0; i < vt->GetElementCount(); ++i) {
            llvm::Value *ti = ctx->ExtractInst(testVal, i);
            llvm::Value *e1i = ctx->ExtractInst(expr1Val, i);
            llvm::Value *e2i = ctx->ExtractInst(expr2Val, i);
            llvm::Value *sel = NULL;
            if (testType->IsUniformType()) {
                // Extracting uniform vector bool to uniform bool require
                // switching from i8 -> i1
                ti = ctx->SwitchBoolSize(ti, LLVMTypes::BoolType);
                sel = ctx->SelectInst(ti, e1i, e2i);
            } else {
                // Extracting varying vector bools to varying bools require
                // switching from <WIDTH x i8> -> <WIDTH x MaskType>
                ti = ctx->SwitchBoolSize(ti, LLVMTypes::BoolVectorType);
                sel = lEmitVaryingSelect(ctx, ti, e1i, e2i, vt->GetElementType());
            }
            result = ctx->InsertInst(result, sel, i);
        }
        return result;
    }
}
3310 
GetType() const3311 const Type *SelectExpr::GetType() const {
3312     if (!test || !expr1 || !expr2)
3313         return NULL;
3314 
3315     const Type *testType = test->GetType();
3316     const Type *expr1Type = expr1->GetType();
3317     const Type *expr2Type = expr2->GetType();
3318 
3319     if (!testType || !expr1Type || !expr2Type)
3320         return NULL;
3321 
3322     bool becomesVarying = (testType->IsVaryingType() || expr1Type->IsVaryingType() || expr2Type->IsVaryingType());
3323     // if expr1 and expr2 have different vector sizes, typechecking should fail...
3324     int testVecSize = CastType<VectorType>(testType) != NULL ? CastType<VectorType>(testType)->GetElementCount() : 0;
3325     int expr1VecSize = CastType<VectorType>(expr1Type) != NULL ? CastType<VectorType>(expr1Type)->GetElementCount() : 0;
3326     AssertPos(pos, !(testVecSize != 0 && expr1VecSize != 0 && testVecSize != expr1VecSize));
3327 
3328     int vectorSize = std::max(testVecSize, expr1VecSize);
3329     return Type::MoreGeneralType(expr1Type, expr2Type, Union(expr1->pos, expr2->pos), "select expression",
3330                                  becomesVarying, vectorSize);
3331 }
3332 
3333 template <typename T>
lConstFoldSelect(const bool bv[],ConstExpr * constExpr1,ConstExpr * constExpr2,const Type * exprType,SourcePos pos)3334 Expr *lConstFoldSelect(const bool bv[], ConstExpr *constExpr1, ConstExpr *constExpr2, const Type *exprType,
3335                        SourcePos pos) {
3336     T v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC];
3337     T result[ISPC_MAX_NVEC];
3338     int count = constExpr1->GetValues(v1);
3339     constExpr2->GetValues(v2);
3340     for (int i = 0; i < count; ++i)
3341         result[i] = bv[i] ? v1[i] : v2[i];
3342     return new ConstExpr(exprType, result, pos);
3343 }
3344 
Optimize()3345 Expr *SelectExpr::Optimize() {
3346     if (test == NULL || expr1 == NULL || expr2 == NULL)
3347         return NULL;
3348 
3349     ConstExpr *constTest = llvm::dyn_cast<ConstExpr>(test);
3350     if (constTest == NULL)
3351         return this;
3352 
3353     // The test is a constant; see if we can resolve to one of the
3354     // expressions..
3355     bool bv[ISPC_MAX_NVEC];
3356     int count = constTest->GetValues(bv);
3357     if (count == 1)
3358         // Uniform test value; return the corresponding expression
3359         return (bv[0] == true) ? expr1 : expr2;
3360     else {
3361         // Varying test: see if all of the values are the same; if so, then
3362         // return the corresponding expression
3363         bool first = bv[0];
3364         bool mismatch = false;
3365         for (int i = 0; i < count; ++i)
3366             if (bv[i] != first) {
3367                 mismatch = true;
3368                 break;
3369             }
3370         if (mismatch == false)
3371             return (bv[0] == true) ? expr1 : expr2;
3372 
3373         // Last chance: see if the two expressions are constants; if so,
3374         // then we can do an element-wise selection based on the constant
3375         // condition..
3376         ConstExpr *constExpr1 = llvm::dyn_cast<ConstExpr>(expr1);
3377         ConstExpr *constExpr2 = llvm::dyn_cast<ConstExpr>(expr2);
3378         if (constExpr1 == NULL || constExpr2 == NULL)
3379             return this;
3380 
3381         AssertPos(pos, Type::Equal(constExpr1->GetType(), constExpr2->GetType()));
3382         const Type *exprType = constExpr1->GetType()->GetAsNonConstType();
3383         AssertPos(pos, exprType->IsVaryingType());
3384 
3385         if (Type::Equal(exprType, AtomicType::VaryingInt8)) {
3386             return lConstFoldSelect<int8_t>(bv, constExpr1, constExpr2, exprType, pos);
3387         } else if (Type::Equal(exprType, AtomicType::VaryingUInt8)) {
3388             return lConstFoldSelect<uint8_t>(bv, constExpr1, constExpr2, exprType, pos);
3389         } else if (Type::Equal(exprType, AtomicType::VaryingInt16)) {
3390             return lConstFoldSelect<int16_t>(bv, constExpr1, constExpr2, exprType, pos);
3391         } else if (Type::Equal(exprType, AtomicType::VaryingUInt16)) {
3392             return lConstFoldSelect<uint16_t>(bv, constExpr1, constExpr2, exprType, pos);
3393         } else if (Type::Equal(exprType, AtomicType::VaryingInt32)) {
3394             return lConstFoldSelect<int32_t>(bv, constExpr1, constExpr2, exprType, pos);
3395         } else if (Type::Equal(exprType, AtomicType::VaryingUInt32)) {
3396             return lConstFoldSelect<uint32_t>(bv, constExpr1, constExpr2, exprType, pos);
3397         } else if (Type::Equal(exprType, AtomicType::VaryingInt64)) {
3398             return lConstFoldSelect<int64_t>(bv, constExpr1, constExpr2, exprType, pos);
3399         } else if (Type::Equal(exprType, AtomicType::VaryingUInt64)) {
3400             return lConstFoldSelect<uint64_t>(bv, constExpr1, constExpr2, exprType, pos);
3401         } else if (Type::Equal(exprType, AtomicType::VaryingFloat)) {
3402             return lConstFoldSelect<float>(bv, constExpr1, constExpr2, exprType, pos);
3403         } else if (Type::Equal(exprType, AtomicType::VaryingDouble)) {
3404             return lConstFoldSelect<bool>(bv, constExpr1, constExpr2, exprType, pos);
3405         } else if (Type::Equal(exprType, AtomicType::VaryingBool)) {
3406             return lConstFoldSelect<double>(bv, constExpr1, constExpr2, exprType, pos);
3407         }
3408 
3409         return this;
3410     }
3411 }
3412 
TypeCheck()3413 Expr *SelectExpr::TypeCheck() {
3414     if (test == NULL || expr1 == NULL || expr2 == NULL)
3415         return NULL;
3416 
3417     const Type *type1 = expr1->GetType(), *type2 = expr2->GetType();
3418     if (!type1 || !type2)
3419         return NULL;
3420 
3421     if (const ArrayType *at1 = CastType<ArrayType>(type1)) {
3422         expr1 = TypeConvertExpr(expr1, PointerType::GetUniform(at1->GetBaseType()), "select");
3423         if (expr1 == NULL)
3424             return NULL;
3425         type1 = expr1->GetType();
3426     }
3427     if (const ArrayType *at2 = CastType<ArrayType>(type2)) {
3428         expr2 = TypeConvertExpr(expr2, PointerType::GetUniform(at2->GetBaseType()), "select");
3429         if (expr2 == NULL)
3430             return NULL;
3431         type2 = expr2->GetType();
3432     }
3433 
3434     const Type *testType = test->GetType();
3435     if (testType == NULL)
3436         return NULL;
3437     test = TypeConvertExpr(test, lMatchingBoolType(testType), "select");
3438     if (test == NULL)
3439         return NULL;
3440     testType = test->GetType();
3441 
3442     int testVecSize = CastType<VectorType>(testType) ? CastType<VectorType>(testType)->GetElementCount() : 0;
3443     const Type *promotedType = Type::MoreGeneralType(type1, type2, Union(expr1->pos, expr2->pos), "select expression",
3444                                                      testType->IsVaryingType(), testVecSize);
3445 
3446     // If the promoted type is a ReferenceType, the expression type will be
3447     // the reference target type since SelectExpr is always a rvalue.
3448     if (CastType<ReferenceType>(promotedType) != NULL)
3449         promotedType = promotedType->GetReferenceTarget();
3450 
3451     if (promotedType == NULL)
3452         return NULL;
3453 
3454     expr1 = TypeConvertExpr(expr1, promotedType, "select");
3455     expr2 = TypeConvertExpr(expr2, promotedType, "select");
3456     if (expr1 == NULL || expr2 == NULL)
3457         return NULL;
3458 
3459     return this;
3460 }
3461 
EstimateCost() const3462 int SelectExpr::EstimateCost() const { return COST_SELECT; }
3463 
Print() const3464 void SelectExpr::Print() const {
3465     if (!test || !expr1 || !expr2 || !GetType())
3466         return;
3467 
3468     printf("[%s] (", GetType()->GetString().c_str());
3469     test->Print();
3470     printf(" ? ");
3471     expr1->Print();
3472     printf(" : ");
3473     expr2->Print();
3474     printf(")");
3475     pos.Print();
3476 }
3477 
3478 ///////////////////////////////////////////////////////////////////////////
3479 // FunctionCallExpr
3480 
FunctionCallExpr(Expr * f,ExprList * a,SourcePos p,bool il,Expr * lce[3])3481 FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, bool il, Expr *lce[3])
3482     : Expr(p, FunctionCallExprID), isLaunch(il) {
3483     func = f;
3484     args = a;
3485     std::vector<const Expr *> warn;
3486     if (a->HasAmbiguousVariability(warn) == true) {
3487         for (auto w : warn) {
3488             const TypeCastExpr *tExpr = llvm::dyn_cast<TypeCastExpr>(w);
3489             tExpr->PrintAmbiguousVariability();
3490         }
3491     }
3492     if (lce != NULL) {
3493         launchCountExpr[0] = lce[0];
3494         launchCountExpr[1] = lce[1];
3495         launchCountExpr[2] = lce[2];
3496     } else
3497         launchCountExpr[0] = launchCountExpr[1] = launchCountExpr[2] = NULL;
3498 }
3499 
lGetFunctionType(Expr * func)3500 static const FunctionType *lGetFunctionType(Expr *func) {
3501     if (func == NULL)
3502         return NULL;
3503 
3504     const Type *type = func->GetType();
3505     if (type == NULL)
3506         return NULL;
3507 
3508     const FunctionType *ftype = CastType<FunctionType>(type);
3509     if (ftype == NULL) {
3510         // Not a regular function symbol--is it a function pointer?
3511         if (CastType<PointerType>(type) != NULL)
3512             ftype = CastType<FunctionType>(type->GetBaseType());
3513     }
3514     return ftype;
3515 }
3516 
GetValue(FunctionEmitContext * ctx) const3517 llvm::Value *FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const {
3518     if (func == NULL || args == NULL)
3519         return NULL;
3520 
3521     ctx->SetDebugPos(pos);
3522 
3523     llvm::Value *callee = func->GetValue(ctx);
3524 
3525     if (callee == NULL) {
3526         AssertPos(pos, m->errorCount > 0);
3527         return NULL;
3528     }
3529 
3530     const FunctionType *ft = lGetFunctionType(func);
3531     AssertPos(pos, ft != NULL);
3532     bool isVoidFunc = ft->GetReturnType()->IsVoidType();
3533 
3534     // Automatically convert function call args to references if needed.
3535     // FIXME: this should move to the TypeCheck() method... (but the
3536     // GetLValue call below needs a FunctionEmitContext, which is
3537     // problematic...)
3538     std::vector<Expr *> callargs = args->exprs;
3539 
3540     // Specifically, this can happen if there's an error earlier during
3541     // overload resolution.
3542     if ((int)callargs.size() > ft->GetNumParameters()) {
3543         AssertPos(pos, m->errorCount > 0);
3544         return NULL;
3545     }
3546 
3547     for (unsigned int i = 0; i < callargs.size(); ++i) {
3548         Expr *argExpr = callargs[i];
3549         if (argExpr == NULL)
3550             continue;
3551 
3552         const Type *paramType = ft->GetParameterType(i);
3553 
3554         const Type *argLValueType = argExpr->GetLValueType();
3555         if (argLValueType != NULL && CastType<PointerType>(argLValueType) != NULL && argLValueType->IsVaryingType() &&
3556             CastType<ReferenceType>(paramType) != NULL) {
3557             Error(argExpr->pos,
3558                   "Illegal to pass a \"varying\" lvalue to a "
3559                   "reference parameter of type \"%s\".",
3560                   paramType->GetString().c_str());
3561             return NULL;
3562         }
3563 
3564         // Do whatever type conversion is needed
3565         argExpr = TypeConvertExpr(argExpr, paramType, "function call argument");
3566         if (argExpr == NULL)
3567             return NULL;
3568         callargs[i] = argExpr;
3569     }
3570 
3571     // Fill in any default argument values needed.
3572     // FIXME: should we do this during type checking?
3573     for (int i = callargs.size(); i < ft->GetNumParameters(); ++i) {
3574         Expr *paramDefault = ft->GetParameterDefault(i);
3575         const Type *paramType = ft->GetParameterType(i);
3576         // FIXME: this type conv should happen when we create the function
3577         // type!
3578         Expr *d = TypeConvertExpr(paramDefault, paramType, "function call default argument");
3579         if (d == NULL)
3580             return NULL;
3581         callargs.push_back(d);
3582     }
3583 
3584     // Now evaluate the values of all of the parameters being passed.
3585     std::vector<llvm::Value *> argVals;
3586     for (unsigned int i = 0; i < callargs.size(); ++i) {
3587         Expr *argExpr = callargs[i];
3588         if (argExpr == NULL)
3589             // give up; we hit an error earlier
3590             return NULL;
3591 
3592         llvm::Value *argValue = argExpr->GetValue(ctx);
3593         if (argValue == NULL)
3594             // something went wrong in evaluating the argument's
3595             // expression, so give up on this
3596             return NULL;
3597 
3598         argVals.push_back(argValue);
3599     }
3600 
3601     llvm::Value *retVal = NULL;
3602     ctx->SetDebugPos(pos);
3603     if (ft->isTask) {
3604         AssertPos(pos, launchCountExpr[0] != NULL);
3605         llvm::Value *launchCount[3] = {launchCountExpr[0]->GetValue(ctx), launchCountExpr[1]->GetValue(ctx),
3606                                        launchCountExpr[2]->GetValue(ctx)};
3607 
3608         if (launchCount[0] != NULL)
3609             ctx->LaunchInst(callee, argVals, launchCount);
3610     } else
3611         retVal = ctx->CallInst(callee, ft, argVals, isVoidFunc ? "" : "calltmp");
3612 
3613     if (isVoidFunc)
3614         return NULL;
3615     else
3616         return retVal;
3617 }
3618 
GetLValue(FunctionEmitContext * ctx) const3619 llvm::Value *FunctionCallExpr::GetLValue(FunctionEmitContext *ctx) const {
3620     if (GetLValueType() != NULL) {
3621         return GetValue(ctx);
3622     } else {
3623         // Only be a valid LValue type if the function
3624         // returns a pointer or reference.
3625         return NULL;
3626     }
3627 }
3628 
FullResolveOverloads(Expr * func,ExprList * args,std::vector<const Type * > * argTypes,std::vector<bool> * argCouldBeNULL,std::vector<bool> * argIsConstant)3629 bool FullResolveOverloads(Expr *func, ExprList *args, std::vector<const Type *> *argTypes,
3630                           std::vector<bool> *argCouldBeNULL, std::vector<bool> *argIsConstant) {
3631     for (unsigned int i = 0; i < args->exprs.size(); ++i) {
3632         Expr *expr = args->exprs[i];
3633         if (expr == NULL)
3634             return false;
3635         const Type *t = expr->GetType();
3636         if (t == NULL)
3637             return false;
3638         argTypes->push_back(t);
3639         argCouldBeNULL->push_back(lIsAllIntZeros(expr) || llvm::dyn_cast<NullPointerExpr>(expr));
3640         argIsConstant->push_back(llvm::dyn_cast<ConstExpr>(expr) || llvm::dyn_cast<NullPointerExpr>(expr));
3641     }
3642     return true;
3643 }
3644 
GetType() const3645 const Type *FunctionCallExpr::GetType() const {
3646     std::vector<const Type *> argTypes;
3647     std::vector<bool> argCouldBeNULL, argIsConstant;
3648     if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == true) {
3649         FunctionSymbolExpr *fse = llvm::dyn_cast<FunctionSymbolExpr>(func);
3650         if (fse != NULL) {
3651             fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, &argIsConstant);
3652         }
3653     }
3654     const FunctionType *ftype = lGetFunctionType(func);
3655     return ftype ? ftype->GetReturnType() : NULL;
3656 }
3657 
GetLValueType() const3658 const Type *FunctionCallExpr::GetLValueType() const {
3659     const FunctionType *ftype = lGetFunctionType(func);
3660     if (ftype && (ftype->GetReturnType()->IsPointerType() || ftype->GetReturnType()->IsReferenceType())) {
3661         return ftype->GetReturnType();
3662     } else {
3663         // Only be a valid LValue type if the function
3664         // returns a pointer or reference.
3665         return NULL;
3666     }
3667 }
3668 
Optimize()3669 Expr *FunctionCallExpr::Optimize() {
3670     if (func == NULL || args == NULL)
3671         return NULL;
3672     return this;
3673 }
3674 
TypeCheck()3675 Expr *FunctionCallExpr::TypeCheck() {
3676     if (func == NULL || args == NULL)
3677         return NULL;
3678 
3679     std::vector<const Type *> argTypes;
3680     std::vector<bool> argCouldBeNULL, argIsConstant;
3681 
3682     if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == false) {
3683         return NULL;
3684     }
3685 
3686     FunctionSymbolExpr *fse = llvm::dyn_cast<FunctionSymbolExpr>(func);
3687     if (fse != NULL) {
3688         // Regular function call
3689         if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, &argIsConstant) == false)
3690             return NULL;
3691 
3692         func = ::TypeCheck(fse);
3693         if (func == NULL)
3694             return NULL;
3695 
3696         const FunctionType *ft = CastType<FunctionType>(func->GetType());
3697         if (ft == NULL) {
3698             const PointerType *pt = CastType<PointerType>(func->GetType());
3699             ft = (pt == NULL) ? NULL : CastType<FunctionType>(pt->GetBaseType());
3700         }
3701 
3702         if (ft == NULL) {
3703             Error(pos, "Valid function name must be used for function call.");
3704             return NULL;
3705         }
3706 
3707         if (ft->isTask) {
3708             if (!isLaunch)
3709                 Error(pos, "\"launch\" expression needed to call function "
3710                            "with \"task\" qualifier.");
3711             for (int k = 0; k < 3; k++) {
3712                 if (!launchCountExpr[k])
3713                     return NULL;
3714 
3715                 launchCountExpr[k] = TypeConvertExpr(launchCountExpr[k], AtomicType::UniformInt32, "task launch count");
3716                 if (launchCountExpr[k] == NULL)
3717                     return NULL;
3718             }
3719         } else {
3720             if (isLaunch) {
3721                 Error(pos, "\"launch\" expression illegal with non-\"task\"-"
3722                            "qualified function.");
3723                 return NULL;
3724             }
3725             AssertPos(pos, launchCountExpr[0] == NULL);
3726         }
3727     } else {
3728         // Call through a function pointer
3729         const Type *fptrType = func->GetType();
3730         if (fptrType == NULL)
3731             return NULL;
3732 
3733         // Make sure we do in fact have a function to call
3734         const FunctionType *funcType;
3735         if (CastType<PointerType>(fptrType) == NULL ||
3736             (funcType = CastType<FunctionType>(fptrType->GetBaseType())) == NULL) {
3737             Error(func->pos, "Must provide function name or function pointer for "
3738                              "function call expression.");
3739             return NULL;
3740         }
3741 
3742         // Make sure we don't have too many arguments for the function
3743         if ((int)argTypes.size() > funcType->GetNumParameters()) {
3744             Error(args->pos,
3745                   "Too many parameter values provided in "
3746                   "function call (%d provided, %d expected).",
3747                   (int)argTypes.size(), funcType->GetNumParameters());
3748             return NULL;
3749         }
3750         // It's ok to have too few arguments, as long as the function's
3751         // default parameter values have started by the time we run out
3752         // of arguments
3753         if ((int)argTypes.size() < funcType->GetNumParameters() &&
3754             funcType->GetParameterDefault(argTypes.size()) == NULL) {
3755             Error(args->pos,
3756                   "Too few parameter values provided in "
3757                   "function call (%d provided, %d expected).",
3758                   (int)argTypes.size(), funcType->GetNumParameters());
3759             return NULL;
3760         }
3761 
3762         // Now make sure they can all type convert to the corresponding
3763         // parameter types..
3764         for (int i = 0; i < (int)argTypes.size(); ++i) {
3765             if (i < funcType->GetNumParameters()) {
3766                 // make sure it can type convert
3767                 const Type *paramType = funcType->GetParameterType(i);
3768                 if (CanConvertTypes(argTypes[i], paramType) == false &&
3769                     !(argCouldBeNULL[i] == true && CastType<PointerType>(paramType) != NULL)) {
3770                     Error(args->exprs[i]->pos,
3771                           "Can't convert argument of "
3772                           "type \"%s\" to type \"%s\" for function call "
3773                           "argument.",
3774                           argTypes[i]->GetString().c_str(), paramType->GetString().c_str());
3775                     return NULL;
3776                 }
3777             } else
3778                 // Otherwise the parameter default saves us.  It should
3779                 // be there for sure, given the check right above the
3780                 // for loop.
3781                 AssertPos(pos, funcType->GetParameterDefault(i) != NULL);
3782         }
3783 
3784         if (fptrType->IsVaryingType()) {
3785             const Type *retType = funcType->GetReturnType();
3786             if (retType->IsVoidType() == false && retType->IsUniformType()) {
3787                 Error(pos,
3788                       "Illegal to call a varying function pointer that "
3789                       "points to a function with a uniform return type \"%s\".",
3790                       funcType->GetReturnType()->GetString().c_str());
3791                 return NULL;
3792             }
3793         }
3794     }
3795 
3796     if (func == NULL || args == NULL)
3797         return NULL;
3798     return this;
3799 }
3800 
EstimateCost() const3801 int FunctionCallExpr::EstimateCost() const {
3802     if (isLaunch)
3803         return COST_TASK_LAUNCH;
3804 
3805     const Type *type = func->GetType();
3806     if (type == NULL)
3807         return 0;
3808 
3809     const PointerType *pt = CastType<PointerType>(type);
3810     if (pt != NULL)
3811         type = type->GetBaseType();
3812 
3813     const FunctionType *ftype = CastType<FunctionType>(type);
3814     if (ftype != NULL && ftype->costOverride > -1)
3815         return ftype->costOverride;
3816 
3817     if (pt != NULL)
3818         return pt->IsUniformType() ? COST_FUNPTR_UNIFORM : COST_FUNPTR_VARYING;
3819     else
3820         return COST_FUNCALL;
3821 }
3822 
Print() const3823 void FunctionCallExpr::Print() const {
3824     if (!func || !args || !GetType())
3825         return;
3826 
3827     printf("[%s] funcall %s ", GetType()->GetString().c_str(), isLaunch ? "launch" : "");
3828     func->Print();
3829     printf(" args (");
3830     args->Print();
3831     printf(")");
3832     pos.Print();
3833 }
3834 
3835 ///////////////////////////////////////////////////////////////////////////
3836 // ExprList
3837 
HasAmbiguousVariability(std::vector<const Expr * > & warn) const3838 bool ExprList::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
3839     bool hasAmbiguousVariability = false;
3840     for (unsigned int i = 0; i < exprs.size(); ++i) {
3841         if (exprs[i] != NULL) {
3842             hasAmbiguousVariability |= exprs[i]->HasAmbiguousVariability(warn);
3843         }
3844     }
3845     return hasAmbiguousVariability;
3846 }
3847 
GetValue(FunctionEmitContext * ctx) const3848 llvm::Value *ExprList::GetValue(FunctionEmitContext *ctx) const {
3849     FATAL("ExprList::GetValue() should never be called");
3850     return NULL;
3851 }
3852 
GetType() const3853 const Type *ExprList::GetType() const {
3854     FATAL("ExprList::GetType() should never be called");
3855     return NULL;
3856 }
3857 
Optimize()3858 ExprList *ExprList::Optimize() { return this; }
3859 
TypeCheck()3860 ExprList *ExprList::TypeCheck() { return this; }
3861 
lGetExprListConstant(const Type * type,const ExprList * eList,bool isStorageType)3862 static std::pair<llvm::Constant *, bool> lGetExprListConstant(const Type *type, const ExprList *eList,
3863                                                               bool isStorageType) {
3864     std::vector<Expr *> exprs = eList->exprs;
3865     SourcePos pos = eList->pos;
3866     bool isVaryingInit = false;
3867     bool isNotValidForMultiTargetGlobal = false;
3868     if (exprs.size() == 1 && (CastType<AtomicType>(type) != NULL || CastType<EnumType>(type) != NULL ||
3869                               CastType<PointerType>(type) != NULL)) {
3870         if (isStorageType)
3871             return exprs[0]->GetStorageConstant(type);
3872         else
3873             return exprs[0]->GetConstant(type);
3874     }
3875 
3876     const CollectionType *collectionType = CastType<CollectionType>(type);
3877     if (collectionType == NULL) {
3878         if (type->IsVaryingType() == true) {
3879             isVaryingInit = true;
3880         } else
3881             return std::pair<llvm::Constant *, bool>(NULL, false);
3882     }
3883 
3884     std::string name;
3885     if (CastType<StructType>(type) != NULL)
3886         name = "struct";
3887     else if (CastType<ArrayType>(type) != NULL)
3888         name = "array";
3889     else if (CastType<VectorType>(type) != NULL)
3890         name = "vector";
3891     else if (isVaryingInit == true)
3892         name = "varying";
3893     else
3894         FATAL("Unexpected CollectionType in lGetExprListConstant");
3895 
3896     int elementCount = (isVaryingInit == true) ? g->target->getVectorWidth() : collectionType->GetElementCount();
3897     if ((int)exprs.size() > elementCount) {
3898         const Type *errType = (isVaryingInit == true) ? type : collectionType;
3899         Error(pos,
3900               "Initializer list for %s \"%s\" must have no more than %d "
3901               "elements (has %d).",
3902               name.c_str(), errType->GetString().c_str(), elementCount, (int)exprs.size());
3903         return std::pair<llvm::Constant *, bool>(NULL, false);
3904     } else if ((isVaryingInit == true) && ((int)exprs.size() < elementCount)) {
3905         Error(pos,
3906               "Initializer list for %s \"%s\" must have %d "
3907               "elements (has %d).",
3908               name.c_str(), type->GetString().c_str(), elementCount, (int)exprs.size());
3909         return std::pair<llvm::Constant *, bool>(NULL, false);
3910     }
3911 
3912     std::vector<llvm::Constant *> cv;
3913     for (unsigned int i = 0; i < exprs.size(); ++i) {
3914         if (exprs[i] == NULL)
3915             return std::pair<llvm::Constant *, bool>(NULL, false);
3916         const Type *elementType =
3917             (isVaryingInit == true) ? type->GetAsUniformType() : collectionType->GetElementType(i);
3918 
3919         Expr *expr = exprs[i];
3920 
3921         if (llvm::dyn_cast<ExprList>(expr) == NULL) {
3922             // If there's a simple type conversion from the type of this
3923             // expression to the type we need, then let the regular type
3924             // conversion machinery handle it.
3925             expr = TypeConvertExpr(exprs[i], elementType, "initializer list");
3926             if (expr == NULL) {
3927                 AssertPos(pos, m->errorCount > 0);
3928                 return std::pair<llvm::Constant *, bool>(NULL, false);
3929             }
3930             // Re-establish const-ness if possible
3931             expr = ::Optimize(expr);
3932         }
3933         std::pair<llvm::Constant *, bool> cPair;
3934         if (isStorageType)
3935             cPair = expr->GetStorageConstant(elementType);
3936         else
3937             cPair = expr->GetConstant(elementType);
3938         llvm::Constant *c = cPair.first;
3939         if (c == NULL)
3940             // If this list element couldn't convert to the right constant
3941             // type for the corresponding collection member, then give up.
3942             return std::pair<llvm::Constant *, bool>(NULL, false);
3943         isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || cPair.second;
3944         cv.push_back(c);
3945     }
3946 
3947     // If there are too few, then treat missing ones as if they were zero
3948     if (isVaryingInit == false) {
3949         for (int i = (int)exprs.size(); i < collectionType->GetElementCount(); ++i) {
3950             const Type *elementType = collectionType->GetElementType(i);
3951             if (elementType == NULL) {
3952                 AssertPos(pos, m->errorCount > 0);
3953                 return std::pair<llvm::Constant *, bool>(NULL, false);
3954             }
3955             llvm::Type *llvmType = elementType->LLVMType(g->ctx);
3956             if (llvmType == NULL) {
3957                 AssertPos(pos, m->errorCount > 0);
3958                 return std::pair<llvm::Constant *, bool>(NULL, false);
3959             }
3960 
3961             llvm::Constant *c = llvm::Constant::getNullValue(llvmType);
3962             cv.push_back(c);
3963         }
3964     }
3965 
3966     if (CastType<StructType>(type) != NULL) {
3967         llvm::StructType *llvmStructType = llvm::dyn_cast<llvm::StructType>(collectionType->LLVMType(g->ctx));
3968         AssertPos(pos, llvmStructType != NULL);
3969         return std::pair<llvm::Constant *, bool>(llvm::ConstantStruct::get(llvmStructType, cv),
3970                                                  isNotValidForMultiTargetGlobal);
3971     } else {
3972         llvm::Type *lt = type->LLVMType(g->ctx);
3973         llvm::ArrayType *lat = llvm::dyn_cast<llvm::ArrayType>(lt);
3974         if (lat != NULL)
3975             return std::pair<llvm::Constant *, bool>(llvm::ConstantArray::get(lat, cv), isNotValidForMultiTargetGlobal);
3976         else if (type->IsVaryingType()) {
3977             // uniform short vector type
3978             llvm::VectorType *lvt = llvm::dyn_cast<llvm::VectorType>(lt);
3979             AssertPos(pos, lvt != NULL);
3980             int vectorWidth = g->target->getVectorWidth();
3981 
3982             while ((cv.size() % vectorWidth) != 0) {
3983                 cv.push_back(llvm::UndefValue::get(lvt->getElementType()));
3984             }
3985 
3986             return std::pair<llvm::Constant *, bool>(llvm::ConstantVector::get(cv), isNotValidForMultiTargetGlobal);
3987         } else {
3988             // uniform short vector type
3989             AssertPos(pos, type->IsUniformType() && CastType<VectorType>(type) != NULL);
3990 
3991             llvm::VectorType *lvt = llvm::dyn_cast<llvm::VectorType>(lt);
3992             AssertPos(pos, lvt != NULL);
3993 
3994             // Uniform short vectors are stored as vectors of length
3995             // rounded up to a power of 2 bits in size but not less then 128 bit.
3996             // So we add additional undef values here until we get the right size.
3997             const VectorType *vt = CastType<VectorType>(type);
3998             int vectorWidth = vt->getVectorMemoryCount();
3999 
4000             while ((cv.size() % vectorWidth) != 0) {
4001                 cv.push_back(llvm::UndefValue::get(lvt->getElementType()));
4002             }
4003 
4004             return std::pair<llvm::Constant *, bool>(llvm::ConstantVector::get(cv), isNotValidForMultiTargetGlobal);
4005         }
4006     }
4007     return std::pair<llvm::Constant *, bool>(NULL, false);
4008 }
4009 
GetStorageConstant(const Type * type) const4010 std::pair<llvm::Constant *, bool> ExprList::GetStorageConstant(const Type *type) const {
4011     return lGetExprListConstant(type, this, true);
4012 }
GetConstant(const Type * type) const4013 std::pair<llvm::Constant *, bool> ExprList::GetConstant(const Type *type) const {
4014     return lGetExprListConstant(type, this, false);
4015 }
4016 
EstimateCost() const4017 int ExprList::EstimateCost() const { return 0; }
4018 
Print() const4019 void ExprList::Print() const {
4020     printf("expr list (");
4021     for (unsigned int i = 0; i < exprs.size(); ++i) {
4022         if (exprs[i] != NULL)
4023             exprs[i]->Print();
4024         printf("%s", (i == exprs.size() - 1) ? ")" : ", ");
4025     }
4026     pos.Print();
4027 }
4028 
4029 ///////////////////////////////////////////////////////////////////////////
4030 // IndexExpr
4031 
IndexExpr(Expr * a,Expr * i,SourcePos p)4032 IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p) : Expr(p, IndexExprID) {
4033     baseExpr = a;
4034     index = i;
4035     type = lvalueType = NULL;
4036 }
4037 
4038 /** When computing pointer values, we need to apply a per-lane offset when
4039     we have a varying pointer that is itself indexing into varying data.
4040     Consdier the following ispc code:
4041 
4042     uniform float u[] = ...;
4043     float v[] = ...;
4044     int index = ...;
4045     float a = u[index];
4046     float b = v[index];
4047 
4048     To compute the varying pointer that holds the addresses to load from
4049     for u[index], we basically just need to multiply index element-wise by
4050     sizeof(float) before doing the memory load.  For v[index], we need to
4051     do the same scaling but also need to add per-lane offsets <0,
4052     sizeof(float), 2*sizeof(float), ...> so that the i'th lane loads the
4053     i'th of the varying values at its index value.
4054 
4055     This function handles figuring out when this additional offset is
4056     needed and then incorporates it in the varying pointer value.
4057  */
lAddVaryingOffsetsIfNeeded(FunctionEmitContext * ctx,llvm::Value * ptr,const Type * ptrRefType)4058 static llvm::Value *lAddVaryingOffsetsIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr, const Type *ptrRefType) {
4059     if (CastType<ReferenceType>(ptrRefType) != NULL)
4060         // References are uniform pointers, so no offsetting is needed
4061         return ptr;
4062 
4063     const PointerType *ptrType = CastType<PointerType>(ptrRefType);
4064     Assert(ptrType != NULL);
4065     if (ptrType->IsUniformType() || ptrType->IsSlice())
4066         return ptr;
4067 
4068     const Type *baseType = ptrType->GetBaseType();
4069     if (baseType->IsVaryingType() == false)
4070         return ptr;
4071 
4072     // must be indexing into varying atomic, enum, or pointer types
4073     if (Type::IsBasicType(baseType) == false)
4074         return ptr;
4075 
4076     // Onward: compute the per lane offsets.
4077     llvm::Value *varyingOffsets = ctx->ProgramIndexVector();
4078 
4079     // And finally add the per-lane offsets.  Note that we lie to the GEP
4080     // call and tell it that the pointers are to uniform elements and not
4081     // varying elements, so that the offsets in terms of (0,1,2,...) will
4082     // end up turning into the correct step in bytes...
4083     const Type *uniformElementType = baseType->GetAsUniformType();
4084     const Type *ptrUnifType = PointerType::GetVarying(uniformElementType);
4085     return ctx->GetElementPtrInst(ptr, varyingOffsets, ptrUnifType);
4086 }
4087 
4088 /** Check to see if the given type is an array of or pointer to a varying
4089     struct type that in turn has a member with bound 'uniform' variability.
4090     Issue an error and return true if such a member is found.
4091  */
lVaryingStructHasUniformMember(const Type * type,SourcePos pos)4092 static bool lVaryingStructHasUniformMember(const Type *type, SourcePos pos) {
4093     if (CastType<VectorType>(type) != NULL || CastType<ReferenceType>(type) != NULL)
4094         return false;
4095 
4096     const StructType *st = CastType<StructType>(type);
4097     if (st == NULL) {
4098         const ArrayType *at = CastType<ArrayType>(type);
4099         if (at != NULL)
4100             st = CastType<StructType>(at->GetElementType());
4101         else {
4102             const PointerType *pt = CastType<PointerType>(type);
4103             if (pt == NULL)
4104                 return false;
4105 
4106             st = CastType<StructType>(pt->GetBaseType());
4107         }
4108 
4109         if (st == NULL)
4110             return false;
4111     }
4112 
4113     if (st->IsVaryingType() == false)
4114         return false;
4115 
4116     for (int i = 0; i < st->GetElementCount(); ++i) {
4117         const Type *eltType = st->GetElementType(i);
4118         if (eltType == NULL) {
4119             AssertPos(pos, m->errorCount > 0);
4120             continue;
4121         }
4122 
4123         if (CastType<StructType>(eltType) != NULL) {
4124             // We know that the enclosing struct is varying at this point,
4125             // so push that down to the enclosed struct before makign the
4126             // recursive call.
4127             eltType = eltType->GetAsVaryingType();
4128             if (lVaryingStructHasUniformMember(eltType, pos))
4129                 return true;
4130         } else if (eltType->IsUniformType()) {
4131             Error(pos,
4132                   "Gather operation is impossible due to the presence of "
4133                   "struct member \"%s\" with uniform type \"%s\" in the "
4134                   "varying struct type \"%s\".",
4135                   st->GetElementName(i).c_str(), eltType->GetString().c_str(), st->GetString().c_str());
4136             return true;
4137         }
4138     }
4139 
4140     return false;
4141 }
4142 
GetValue(FunctionEmitContext * ctx) const4143 llvm::Value *IndexExpr::GetValue(FunctionEmitContext *ctx) const {
4144     const Type *indexType, *returnType;
4145     if (baseExpr == NULL || index == NULL || ((indexType = index->GetType()) == NULL) ||
4146         ((returnType = GetType()) == NULL)) {
4147         AssertPos(pos, m->errorCount > 0);
4148         return NULL;
4149     }
4150 
4151     // If this is going to be a gather, make sure that the varying return
4152     // type can represent the result (i.e. that we don't have a bound
4153     // 'uniform' member in a varying struct...)
4154     if (indexType->IsVaryingType() && lVaryingStructHasUniformMember(returnType, pos))
4155         return NULL;
4156 
4157     ctx->SetDebugPos(pos);
4158 
4159     llvm::Value *ptr = GetLValue(ctx);
4160     llvm::Value *mask = NULL;
4161     const Type *lvType = GetLValueType();
4162     if (ptr == NULL) {
4163         // We may be indexing into a temporary that hasn't hit memory, so
4164         // get the full value and stuff it into temporary alloca'd space so
4165         // that we can index from there...
4166         const Type *baseExprType = baseExpr->GetType();
4167         llvm::Value *val = baseExpr->GetValue(ctx);
4168         if (baseExprType == NULL || val == NULL) {
4169             AssertPos(pos, m->errorCount > 0);
4170             return NULL;
4171         }
4172         ctx->SetDebugPos(pos);
4173         llvm::Value *tmpPtr = ctx->AllocaInst(baseExprType, "array_tmp");
4174         ctx->StoreInst(val, tmpPtr, baseExprType, baseExprType->IsUniformType());
4175 
4176         // Get a pointer type to the underlying elements
4177         const SequentialType *st = CastType<SequentialType>(baseExprType);
4178         if (st == NULL) {
4179             Assert(m->errorCount > 0);
4180             return NULL;
4181         }
4182         lvType = PointerType::GetUniform(st->GetElementType());
4183 
4184         // And do the indexing calculation into the temporary array in memory
4185         ptr = ctx->GetElementPtrInst(tmpPtr, LLVMInt32(0), index->GetValue(ctx), PointerType::GetUniform(baseExprType));
4186         ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, lvType);
4187 
4188         mask = LLVMMaskAllOn;
4189     } else {
4190         Symbol *baseSym = GetBaseSymbol();
4191         if (llvm::dyn_cast<FunctionCallExpr>(baseExpr) == NULL && llvm::dyn_cast<BinaryExpr>(baseExpr) == NULL) {
4192             // Don't check if we're doing a function call or pointer arith
4193             AssertPos(pos, baseSym != NULL);
4194         }
4195         mask = lMaskForSymbol(baseSym, ctx);
4196     }
4197 
4198     ctx->SetDebugPos(pos);
4199     return ctx->LoadInst(ptr, mask, lvType);
4200 }
4201 
GetType() const4202 const Type *IndexExpr::GetType() const {
4203     if (type != NULL)
4204         return type;
4205 
4206     const Type *baseExprType, *indexType;
4207     if (!baseExpr || !index || ((baseExprType = baseExpr->GetType()) == NULL) ||
4208         ((indexType = index->GetType()) == NULL))
4209         return NULL;
4210 
4211     const Type *elementType = NULL;
4212     const PointerType *pointerType = CastType<PointerType>(baseExprType);
4213     if (pointerType != NULL)
4214         // ptr[index] -> type that the pointer points to
4215         elementType = pointerType->GetBaseType();
4216     else if (const SequentialType *sequentialType = CastType<SequentialType>(baseExprType->GetReferenceTarget()))
4217         // sequential type[index] -> element type of the sequential type
4218         elementType = sequentialType->GetElementType();
4219     else
4220         // Not an expression that can be indexed into. Will result in error.
4221         return NULL;
4222 
4223     // If we're indexing into a sequence of SOA types, the result type is
4224     // actually the underlying type, as a uniform or varying.  Get the
4225     // uniform variant of it for starters, then below we'll make it varying
4226     // if the index is varying.
4227     // (If we ever provide a way to index into SOA types and get an entire
4228     // SOA'd struct out of the array, then we won't want to do this in that
4229     // case..)
4230     if (elementType->IsSOAType())
4231         elementType = elementType->GetAsUniformType();
4232 
4233     // If either the index is varying or we're indexing into a varying
4234     // pointer, then the result type is the varying variant of the indexed
4235     // type.
4236     if (indexType->IsUniformType() && (pointerType == NULL || pointerType->IsUniformType()))
4237         type = elementType;
4238     else
4239         type = elementType->GetAsVaryingType();
4240 
4241     return type;
4242 }
4243 
GetBaseSymbol() const4244 Symbol *IndexExpr::GetBaseSymbol() const { return baseExpr ? baseExpr->GetBaseSymbol() : NULL; }
4245 
4246 /** Utility routine that takes a regualr pointer (either uniform or
4247     varying) and returns a slice pointer with zero offsets.
4248  */
lConvertToSlicePointer(FunctionEmitContext * ctx,llvm::Value * ptr,const PointerType * slicePtrType)4249 static llvm::Value *lConvertToSlicePointer(FunctionEmitContext *ctx, llvm::Value *ptr,
4250                                            const PointerType *slicePtrType) {
4251     llvm::Type *llvmSlicePtrType = slicePtrType->LLVMType(g->ctx);
4252     llvm::StructType *sliceStructType = llvm::dyn_cast<llvm::StructType>(llvmSlicePtrType);
4253     Assert(sliceStructType != NULL && sliceStructType->getElementType(0) == ptr->getType());
4254 
4255     // Get a null-initialized struct to take care of having zeros for the
4256     // offsets
4257     llvm::Value *result = llvm::Constant::getNullValue(sliceStructType);
4258     // And replace the pointer in the struct with the given pointer
4259     return ctx->InsertInst(result, ptr, 0, llvm::Twine(ptr->getName()) + "_slice");
4260 }
4261 
4262 /** If the given array index is a compile time constant, check to see if it
4263     value/values don't go past the end of the array; issue a warning if
4264     so.
4265 */
lCheckIndicesVersusBounds(const Type * baseExprType,Expr * index)4266 static void lCheckIndicesVersusBounds(const Type *baseExprType, Expr *index) {
4267     const SequentialType *seqType = CastType<SequentialType>(baseExprType);
4268     if (seqType == NULL)
4269         return;
4270 
4271     int nElements = seqType->GetElementCount();
4272     if (nElements == 0)
4273         // Unsized array...
4274         return;
4275 
4276     // If it's an array of soa<> items, then the number of elements to
4277     // worry about w.r.t. index values is the product of the array size and
4278     // the soa width.
4279     int soaWidth = seqType->GetElementType()->GetSOAWidth();
4280     if (soaWidth > 0)
4281         nElements *= soaWidth;
4282 
4283     ConstExpr *ce = llvm::dyn_cast<ConstExpr>(index);
4284     if (ce == NULL)
4285         return;
4286 
4287     int32_t indices[ISPC_MAX_NVEC];
4288     int count = ce->GetValues(indices);
4289     for (int i = 0; i < count; ++i) {
4290         if (indices[i] < 0 || indices[i] >= nElements)
4291             Warning(index->pos,
4292                     "Array index \"%d\" may be out of bounds for %d "
4293                     "element array.",
4294                     indices[i], nElements);
4295     }
4296 }
4297 
4298 /** Converts the given pointer value to a slice pointer if the pointer
4299     points to SOA'ed data.
4300 */
lConvertPtrToSliceIfNeeded(FunctionEmitContext * ctx,llvm::Value * ptr,const Type ** type)4301 static llvm::Value *lConvertPtrToSliceIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr, const Type **type) {
4302     Assert(*type != NULL);
4303     const PointerType *ptrType = CastType<PointerType>(*type);
4304     Assert(ptrType != NULL);
4305     bool convertToSlice = (ptrType->GetBaseType()->IsSOAType() && ptrType->IsSlice() == false);
4306     if (convertToSlice == false)
4307         return ptr;
4308 
4309     *type = ptrType->GetAsSlice();
4310     return lConvertToSlicePointer(ctx, ptr, ptrType->GetAsSlice());
4311 }
4312 
/** Emits code that computes a pointer to the indexed element.

    Two cases are handled: the base expression has pointer type (its
    value is the base address), or it is an array/vector, possibly behind
    a reference (its lvalue, or the reference value, is the base
    address).  In both cases the base pointer is converted to a slice
    pointer first if it points to SOA data.

    @param ctx  Emission context that receives the generated instructions.
    @return The element pointer, or NULL if an operand, its type, or a
            required value is unavailable.
 */
llvm::Value *IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
    const Type *baseExprType;
    if (baseExpr == NULL || index == NULL || ((baseExprType = baseExpr->GetType()) == NULL)) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    // The index is evaluated before the base expression in every path
    // below.
    ctx->SetDebugPos(pos);
    llvm::Value *indexValue = index->GetValue(ctx);
    if (indexValue == NULL) {
        AssertPos(pos, m->errorCount > 0);
        return NULL;
    }

    ctx->SetDebugPos(pos);
    if (CastType<PointerType>(baseExprType) != NULL) {
        // We're indexing off of a pointer
        llvm::Value *basePtrValue = baseExpr->GetValue(ctx);
        if (basePtrValue == NULL) {
            AssertPos(pos, m->errorCount > 0);
            return NULL;
        }
        ctx->SetDebugPos(pos);

        // Convert to a slice pointer if we're indexing into SOA data
        // (this may also replace baseExprType with the slice pointer type)
        basePtrValue = lConvertPtrToSliceIfNeeded(ctx, basePtrValue, &baseExprType);

        llvm::Value *ptr = ctx->GetElementPtrInst(basePtrValue, indexValue, baseExprType,
                                                  llvm::Twine(basePtrValue->getName()) + "_offset");
        return lAddVaryingOffsetsIfNeeded(ctx, ptr, GetLValueType());
    }

    // Not a pointer: we must be indexing an array or vector (possibly
    // through a reference to one).
    llvm::Value *basePtr = NULL;
    const PointerType *basePtrType = NULL;
    if (CastType<ArrayType>(baseExprType) || CastType<VectorType>(baseExprType)) {
        // Direct array/vector: index from the base expression's lvalue.
        basePtr = baseExpr->GetLValue(ctx);
        basePtrType = CastType<PointerType>(baseExpr->GetLValueType());
        if (baseExpr->GetLValueType())
            AssertPos(pos, basePtrType != NULL);
    } else {
        // Reference to an array/vector: the value of the reference is
        // used as the base pointer.
        baseExprType = baseExprType->GetReferenceTarget();
        AssertPos(pos, CastType<ArrayType>(baseExprType) || CastType<VectorType>(baseExprType));
        basePtr = baseExpr->GetValue(ctx);
        basePtrType = PointerType::GetUniform(baseExprType);
    }
    // No base address available (e.g. the base has no lvalue); callers
    // such as GetValue() fall back to a temporary in that case.
    if (!basePtr)
        return NULL;

    // If possible, check the index value(s) against the size of the array
    lCheckIndicesVersusBounds(baseExprType, index);

    // Convert to a slice pointer if indexing into SOA data
    // (basePtrType is updated in place through the cast pointer)
    basePtr = lConvertPtrToSliceIfNeeded(ctx, basePtr, (const Type **)&basePtrType);

    ctx->SetDebugPos(pos);

    // And do the actual indexing calculation: a leading zero index steps
    // through the pointer to the sequence, then indexValue selects the
    // element.
    llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, LLVMInt32(0), indexValue, basePtrType,
                                              llvm::Twine(basePtr->getName()) + "_offset");
    return lAddVaryingOffsetsIfNeeded(ctx, ptr, GetLValueType());
}
4376 
GetLValueType() const4377 const Type *IndexExpr::GetLValueType() const {
4378     if (lvalueType != NULL)
4379         return lvalueType;
4380 
4381     const Type *baseExprType, *baseExprLValueType, *indexType;
4382     if (baseExpr == NULL || index == NULL || ((baseExprType = baseExpr->GetType()) == NULL) ||
4383         ((baseExprLValueType = baseExpr->GetLValueType()) == NULL) || ((indexType = index->GetType()) == NULL))
4384         return NULL;
4385 
4386     // regularize to a PointerType
4387     if (CastType<ReferenceType>(baseExprLValueType) != NULL) {
4388         const Type *refTarget = baseExprLValueType->GetReferenceTarget();
4389         baseExprLValueType = PointerType::GetUniform(refTarget);
4390     }
4391     AssertPos(pos, CastType<PointerType>(baseExprLValueType) != NULL);
4392 
4393     // Find the type of thing that we're indexing into
4394     const Type *elementType;
4395     const SequentialType *st = CastType<SequentialType>(baseExprLValueType->GetBaseType());
4396     if (st != NULL)
4397         elementType = st->GetElementType();
4398     else {
4399         const PointerType *pt = CastType<PointerType>(baseExprLValueType->GetBaseType());
4400         // This assertion seems overly strict.
4401         // Why does it need to be a pointer to a pointer?
4402         // AssertPos(pos, pt != NULL);
4403 
4404         if (pt != NULL) {
4405             elementType = pt->GetBaseType();
4406         } else {
4407             elementType = baseExprLValueType->GetBaseType();
4408         }
4409     }
4410 
4411     // Are we indexing into a varying type, or are we indexing with a
4412     // varying pointer?
4413     bool baseVarying;
4414     if (CastType<PointerType>(baseExprType) != NULL)
4415         baseVarying = baseExprType->IsVaryingType();
4416     else
4417         baseVarying = baseExprLValueType->IsVaryingType();
4418 
4419     // The return type is uniform iff. the base is a uniform pointer / a
4420     // collection of uniform typed elements and the index is uniform.
4421     if (baseVarying == false && indexType->IsUniformType())
4422         lvalueType = PointerType::GetUniform(elementType);
4423     else
4424         lvalueType = PointerType::GetVarying(elementType);
4425 
4426     // Finally, if we're indexing into an SOA type, then the resulting
4427     // pointer must (currently) be a slice pointer; we don't allow indexing
4428     // the soa-width-wide structs directly.
4429     if (elementType->IsSOAType())
4430         lvalueType = lvalueType->GetAsSlice();
4431 
4432     return lvalueType;
4433 }
4434 
Optimize()4435 Expr *IndexExpr::Optimize() {
4436     if (baseExpr == NULL || index == NULL)
4437         return NULL;
4438     return this;
4439 }
4440 
TypeCheck()4441 Expr *IndexExpr::TypeCheck() {
4442     const Type *indexType;
4443     if (baseExpr == NULL || index == NULL || ((indexType = index->GetType()) == NULL)) {
4444         AssertPos(pos, m->errorCount > 0);
4445         return NULL;
4446     }
4447 
4448     const Type *baseExprType = baseExpr->GetType();
4449     if (baseExprType == NULL) {
4450         AssertPos(pos, m->errorCount > 0);
4451         return NULL;
4452     }
4453 
4454     if (!CastType<SequentialType>(baseExprType->GetReferenceTarget())) {
4455         if (const PointerType *pt = CastType<PointerType>(baseExprType)) {
4456             if (pt->GetBaseType()->IsVoidType()) {
4457                 Error(pos, "Illegal to dereference void pointer type \"%s\".", baseExprType->GetString().c_str());
4458                 return NULL;
4459             }
4460         } else {
4461             Error(pos,
4462                   "Trying to index into non-array, vector, or pointer "
4463                   "type \"%s\".",
4464                   baseExprType->GetString().c_str());
4465             return NULL;
4466         }
4467     }
4468 
4469     bool isUniform = (index->GetType()->IsUniformType() && !g->opt.disableUniformMemoryOptimizations);
4470 
4471     if (!isUniform) {
4472         // Unless we have an explicit 64-bit index and are compiling to a
4473         // 64-bit target with 64-bit addressing, convert the index to an int32
4474         // type.
4475         //    The range of varying index is limited to [0,2^31) as a result.
4476         if (!(Type::EqualIgnoringConst(indexType->GetAsUniformType(), AtomicType::UniformUInt64) ||
4477               Type::EqualIgnoringConst(indexType->GetAsUniformType(), AtomicType::UniformInt64)) ||
4478             g->target->is32Bit() || g->opt.force32BitAddressing) {
4479             const Type *indexType = AtomicType::VaryingInt32;
4480             index = TypeConvertExpr(index, indexType, "array index");
4481             if (index == NULL)
4482                 return NULL;
4483         }
4484     } else { // isUniform
4485         // For 32-bit target:
4486         //   force the index to 32 bit.
4487         // For 64-bit target:
4488         //   We don't want to limit the index range.
4489         //   We sxt/zxt the index to 64 bit right here because
4490         //   LLVM doesn't distinguish unsigned from signed (both are i32)
4491         //
4492         //   However, the index can be still truncated to signed int32 if
4493         //   the index type is 64 bit and --addressing=32.
4494         bool force_32bit =
4495             g->target->is32Bit() || (g->opt.force32BitAddressing &&
4496                                      Type::EqualIgnoringConst(indexType->GetAsUniformType(), AtomicType::UniformInt64));
4497         const Type *indexType = force_32bit ? AtomicType::UniformInt32 : AtomicType::UniformInt64;
4498         index = TypeConvertExpr(index, indexType, "array index");
4499         if (index == NULL)
4500             return NULL;
4501     }
4502 
4503     return this;
4504 }
4505 
EstimateCost() const4506 int IndexExpr::EstimateCost() const {
4507     if (index == NULL || baseExpr == NULL)
4508         return 0;
4509 
4510     const Type *indexType = index->GetType();
4511     const Type *baseExprType = baseExpr->GetType();
4512 
4513     if ((indexType != NULL && indexType->IsVaryingType()) ||
4514         (CastType<PointerType>(baseExprType) != NULL && baseExprType->IsVaryingType()))
4515         // be pessimistic; some of these will later turn out to be vector
4516         // loads/stores, but it's too early for us to know that here.
4517         return COST_GATHER;
4518     else
4519         return COST_LOAD;
4520 }
4521 
Print() const4522 void IndexExpr::Print() const {
4523     if (!baseExpr || !index || !GetType())
4524         return;
4525 
4526     printf("[%s] index ", GetType()->GetString().c_str());
4527     baseExpr->Print();
4528     printf("[");
4529     index->Print();
4530     printf("]");
4531     pos.Print();
4532 }
4533 
4534 ///////////////////////////////////////////////////////////////////////////
4535 // MemberExpr
4536 
/** Maps a one-character component name to its vector element number.
    Three naming conventions are accepted: xyzw, rgba, and uv.  Returns
    -1 for any other character.
*/
static int lIdentifierToVectorElement(char id) {
    if (id == 'x' || id == 'r' || id == 'u')
        return 0;
    if (id == 'y' || id == 'g' || id == 'v')
        return 1;
    if (id == 'z' || id == 'b')
        return 2;
    if (id == 'w' || id == 'a')
        return 3;
    return -1;
}
4560 
4561 //////////////////////////////////////////////////
4562 // StructMemberExpr
4563 
/** MemberExpr specialization for accessing a named member of a struct
    (either directly, through a reference, or through a pointer).
 */
class StructMemberExpr : public MemberExpr {
  public:
    StructMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue);

    // Type-identification hooks keyed on the node's value ID.
    static inline bool classof(StructMemberExpr const *) { return true; }
    static inline bool classof(ASTNode const *N) { return N->getValueID() == StructMemberExprID; }

    // Type of the member's value / of a pointer to the member.
    const Type *GetType() const;
    const Type *GetLValueType() const;
    // Index of the named member in the struct, or -1 if there is none.
    int getElementNumber() const;
    // Declared type of the named member, or NULL if it cannot be found.
    const Type *getElementType() const;

  private:
    // Struct type that the member is being selected from.
    const StructType *getStructType() const;
};
4579 
StructMemberExpr(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue)4580 StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue)
4581     : MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) {}
4582 
GetType() const4583 const Type *StructMemberExpr::GetType() const {
4584     if (type != NULL)
4585         return type;
4586 
4587     // It's a struct, and the result type is the element type, possibly
4588     // promoted to varying if the struct type / lvalue is varying.
4589     const Type *exprType, *lvalueType;
4590     const StructType *structType;
4591     if (expr == NULL || ((exprType = expr->GetType()) == NULL) || ((structType = getStructType()) == NULL) ||
4592         ((lvalueType = GetLValueType()) == NULL)) {
4593         AssertPos(pos, m->errorCount > 0);
4594         return NULL;
4595     }
4596 
4597     const Type *elementType = structType->GetElementType(identifier);
4598     if (elementType == NULL) {
4599         Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s", identifier.c_str(),
4600               structType->GetString().c_str(), getCandidateNearMatches().c_str());
4601         return NULL;
4602     }
4603     AssertPos(pos, Type::Equal(lvalueType->GetBaseType(), elementType));
4604 
4605     bool isSlice = (CastType<PointerType>(lvalueType) && CastType<PointerType>(lvalueType)->IsSlice());
4606     if (isSlice) {
4607         // FIXME: not true if we allow bound unif/varying for soa<>
4608         // structs?...
4609         AssertPos(pos, elementType->IsSOAType());
4610 
4611         // If we're accessing a member of an soa structure via a uniform
4612         // slice pointer, then the result type is the uniform variant of
4613         // the element type.
4614         if (lvalueType->IsUniformType())
4615             elementType = elementType->GetAsUniformType();
4616     }
4617 
4618     if (lvalueType->IsVaryingType())
4619         // If the expression we're getting the member of has an lvalue that
4620         // is a varying pointer type (be it slice or non-slice), then the
4621         // result type must be the varying version of the element type.
4622         elementType = elementType->GetAsVaryingType();
4623 
4624     type = elementType;
4625     return type;
4626 }
4627 
GetLValueType() const4628 const Type *StructMemberExpr::GetLValueType() const {
4629     if (lvalueType != NULL)
4630         return lvalueType;
4631 
4632     if (expr == NULL) {
4633         AssertPos(pos, m->errorCount > 0);
4634         return NULL;
4635     }
4636 
4637     const Type *exprLValueType = dereferenceExpr ? expr->GetType() : expr->GetLValueType();
4638     if (exprLValueType == NULL) {
4639         AssertPos(pos, m->errorCount > 0);
4640         return NULL;
4641     }
4642 
4643     // The pointer type is varying if the lvalue type of the expression is
4644     // varying (and otherwise uniform)
4645     const PointerType *ptrType = (exprLValueType->IsUniformType() || CastType<ReferenceType>(exprLValueType) != NULL)
4646                                      ? PointerType::GetUniform(getElementType())
4647                                      : PointerType::GetVarying(getElementType());
4648 
4649     // If struct pointer is a slice pointer, the resulting member pointer
4650     // needs to be a frozen slice pointer--i.e. any further indexing with
4651     // the result shouldn't modify the minor slice offset, but it should be
4652     // left unchanged until we get to a leaf SOA value.
4653     if (CastType<PointerType>(exprLValueType) && CastType<PointerType>(exprLValueType)->IsSlice())
4654         ptrType = ptrType->GetAsFrozenSlice();
4655 
4656     lvalueType = ptrType;
4657     return lvalueType;
4658 }
4659 
getElementNumber() const4660 int StructMemberExpr::getElementNumber() const {
4661     const StructType *structType = getStructType();
4662     if (structType == NULL)
4663         return -1;
4664 
4665     int elementNumber = structType->GetElementNumber(identifier);
4666     if (elementNumber == -1)
4667         Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s", identifier.c_str(),
4668               structType->GetString().c_str(), getCandidateNearMatches().c_str());
4669 
4670     return elementNumber;
4671 }
4672 
getElementType() const4673 const Type *StructMemberExpr::getElementType() const {
4674     const StructType *structType = getStructType();
4675     if (structType == NULL)
4676         return NULL;
4677 
4678     return structType->GetElementType(identifier);
4679 }
4680 
4681 /** Returns the type of the underlying struct that we're returning a member
4682     of. */
getStructType() const4683 const StructType *StructMemberExpr::getStructType() const {
4684     const Type *type = dereferenceExpr ? expr->GetType() : expr->GetLValueType();
4685     if (type == NULL)
4686         return NULL;
4687 
4688     const Type *structType;
4689     const ReferenceType *rt = CastType<ReferenceType>(type);
4690     if (rt != NULL)
4691         structType = rt->GetReferenceTarget();
4692     else {
4693         const PointerType *pt = CastType<PointerType>(type);
4694         AssertPos(pos, pt != NULL);
4695         structType = pt->GetBaseType();
4696     }
4697 
4698     const StructType *ret = CastType<StructType>(structType);
4699     AssertPos(pos, ret != NULL);
4700     return ret;
4701 }
4702 
4703 //////////////////////////////////////////////////
4704 // VectorMemberExpr
4705 
/** MemberExpr specialization for selecting one or more elements of a
    short vector by component name (e.g. xyzw / rgba / uv identifiers).
 */
class VectorMemberExpr : public MemberExpr {
  public:
    VectorMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue);

    // Type-identification hooks keyed on the node's value ID.
    static inline bool classof(VectorMemberExpr const *) { return true; }
    static inline bool classof(ASTNode const *N) { return N->getValueID() == VectorMemberExprID; }

    // Single-character identifiers defer to MemberExpr for GetValue and
    // GetLValue; longer identifiers are handled in this class.
    llvm::Value *GetValue(FunctionEmitContext *ctx) const;
    llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
    const Type *GetType() const;
    const Type *GetLValueType() const;

    // Element number for the first identifier character, or -1.
    int getElementNumber() const;
    // Returns memberType.
    const Type *getElementType() const;

  private:
    // Vector type of the expression being accessed, resolved through a
    // pointer or reference if needed (set in the constructor).
    const VectorType *exprVectorType;
    // Vector type with exprVectorType's element type and one element per
    // identifier character (set in the constructor).
    const VectorType *memberType;
};
4725 
VectorMemberExpr(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue)4726 VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue)
4727     : MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) {
4728     const Type *exprType = e->GetType();
4729     exprVectorType = CastType<VectorType>(exprType);
4730     if (exprVectorType == NULL) {
4731         const PointerType *pt = CastType<PointerType>(exprType);
4732         if (pt != NULL)
4733             exprVectorType = CastType<VectorType>(pt->GetBaseType());
4734         else {
4735             AssertPos(pos, CastType<ReferenceType>(exprType) != NULL);
4736             exprVectorType = CastType<VectorType>(exprType->GetReferenceTarget());
4737         }
4738         AssertPos(pos, exprVectorType != NULL);
4739     }
4740     memberType = new VectorType(exprVectorType->GetElementType(), identifier.length());
4741 }
4742 
GetType() const4743 const Type *VectorMemberExpr::GetType() const {
4744     if (type != NULL)
4745         return type;
4746 
4747     // For 1-element expressions, we have the base vector element
4748     // type.  For n-element expressions, we have a shortvec type
4749     // with n > 1 elements.  This can be changed when we get
4750     // type<1> -> type conversions.
4751     type = (identifier.length() == 1) ? (const Type *)exprVectorType->GetElementType() : (const Type *)memberType;
4752 
4753     const Type *lvType = GetLValueType();
4754     if (lvType != NULL) {
4755         bool isSlice = (CastType<PointerType>(lvType) && CastType<PointerType>(lvType)->IsSlice());
4756         if (isSlice) {
4757             // CO            AssertPos(pos, type->IsSOAType());
4758             if (lvType->IsUniformType())
4759                 type = type->GetAsUniformType();
4760         }
4761 
4762         if (lvType->IsVaryingType())
4763             type = type->GetAsVaryingType();
4764     }
4765 
4766     return type;
4767 }
4768 
GetLValue(FunctionEmitContext * ctx) const4769 llvm::Value *VectorMemberExpr::GetLValue(FunctionEmitContext *ctx) const {
4770     if (identifier.length() == 1) {
4771         return MemberExpr::GetLValue(ctx);
4772     } else {
4773         return NULL;
4774     }
4775 }
4776 
GetLValueType() const4777 const Type *VectorMemberExpr::GetLValueType() const {
4778     if (lvalueType != NULL)
4779         return lvalueType;
4780 
4781     if (identifier.length() == 1) {
4782         if (expr == NULL) {
4783             AssertPos(pos, m->errorCount > 0);
4784             return NULL;
4785         }
4786 
4787         const Type *exprLValueType = dereferenceExpr ? expr->GetType() : expr->GetLValueType();
4788         if (exprLValueType == NULL)
4789             return NULL;
4790 
4791         const VectorType *vt = NULL;
4792         if (CastType<ReferenceType>(exprLValueType) != NULL)
4793             vt = CastType<VectorType>(exprLValueType->GetReferenceTarget());
4794         else
4795             vt = CastType<VectorType>(exprLValueType->GetBaseType());
4796         AssertPos(pos, vt != NULL);
4797 
4798         // we don't want to report that it's e.g. a pointer to a float<1>,
4799         // but a pointer to a float, etc.
4800         const Type *elementType = vt->GetElementType();
4801         if (CastType<ReferenceType>(exprLValueType) != NULL)
4802             lvalueType = new ReferenceType(elementType);
4803         else {
4804             const PointerType *ptrType = exprLValueType->IsUniformType() ? PointerType::GetUniform(elementType)
4805                                                                          : PointerType::GetVarying(elementType);
4806             // FIXME: replicated logic with structmemberexpr....
4807             if (CastType<PointerType>(exprLValueType) && CastType<PointerType>(exprLValueType)->IsSlice())
4808                 ptrType = ptrType->GetAsFrozenSlice();
4809             lvalueType = ptrType;
4810         }
4811     }
4812 
4813     return lvalueType;
4814 }
4815 
GetValue(FunctionEmitContext * ctx) const4816 llvm::Value *VectorMemberExpr::GetValue(FunctionEmitContext *ctx) const {
4817     if (identifier.length() == 1) {
4818         return MemberExpr::GetValue(ctx);
4819     } else {
4820         std::vector<int> indices;
4821 
4822         for (size_t i = 0; i < identifier.size(); ++i) {
4823             int idx = lIdentifierToVectorElement(identifier[i]);
4824             if (idx == -1)
4825                 Error(pos, "Invalid swizzle character '%c' in swizzle \"%s\".", identifier[i], identifier.c_str());
4826 
4827             indices.push_back(idx);
4828         }
4829 
4830         llvm::Value *basePtr = NULL;
4831         const Type *basePtrType = NULL;
4832         if (dereferenceExpr) {
4833             basePtr = expr->GetValue(ctx);
4834             basePtrType = expr->GetType();
4835         } else {
4836             basePtr = expr->GetLValue(ctx);
4837             basePtrType = expr->GetLValueType();
4838         }
4839 
4840         if (basePtr == NULL || basePtrType == NULL) {
4841             // Check that expression on the left side is a rvalue expression
4842             llvm::Value *exprValue = expr->GetValue(ctx);
4843             basePtr = ctx->AllocaInst(expr->GetType());
4844             basePtrType = PointerType::GetUniform(exprVectorType);
4845             if (basePtr == NULL || basePtrType == NULL) {
4846                 AssertPos(pos, m->errorCount > 0);
4847                 return NULL;
4848             }
4849             ctx->StoreInst(exprValue, basePtr, expr->GetType(), expr->GetType()->IsUniformType());
4850         }
4851 
4852         // Allocate temporary memory to store the result
4853         llvm::Value *resultPtr = ctx->AllocaInst(memberType, "vector_tmp");
4854 
4855         if (resultPtr == NULL) {
4856             AssertPos(pos, m->errorCount > 0);
4857             return NULL;
4858         }
4859 
4860         // FIXME: we should be able to use the internal mask here according
4861         // to the same logic where it's used elsewhere
4862         llvm::Value *elementMask = ctx->GetFullMask();
4863 
4864         const Type *elementPtrType = NULL;
4865         if (CastType<ReferenceType>(basePtrType) != NULL)
4866             elementPtrType = PointerType::GetUniform(basePtrType->GetReferenceTarget());
4867         else
4868             elementPtrType = basePtrType->IsUniformType() ? PointerType::GetUniform(exprVectorType->GetElementType())
4869                                                           : PointerType::GetVarying(exprVectorType->GetElementType());
4870 
4871         ctx->SetDebugPos(pos);
4872         for (size_t i = 0; i < identifier.size(); ++i) {
4873             char idStr[2] = {identifier[i], '\0'};
4874             llvm::Value *elementPtr =
4875                 ctx->AddElementOffset(basePtr, indices[i], basePtrType, llvm::Twine(basePtr->getName()) + idStr);
4876             llvm::Value *elementValue = ctx->LoadInst(elementPtr, elementMask, elementPtrType);
4877 
4878             llvm::Value *ptmp = ctx->AddElementOffset(resultPtr, i, NULL, llvm::Twine(resultPtr->getName()) + idStr);
4879             ctx->StoreInst(elementValue, ptmp, elementPtrType, expr->GetType()->IsUniformType());
4880         }
4881 
4882         return ctx->LoadInst(resultPtr, memberType, llvm::Twine(basePtr->getName()) + "_swizzle");
4883     }
4884 }
4885 
getElementNumber() const4886 int VectorMemberExpr::getElementNumber() const {
4887     int elementNumber = lIdentifierToVectorElement(identifier[0]);
4888     if (elementNumber == -1)
4889         Error(pos, "Vector element identifier \"%s\" unknown.", identifier.c_str());
4890     return elementNumber;
4891 }
4892 
getElementType() const4893 const Type *VectorMemberExpr::getElementType() const { return memberType; }
4894 
create(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue)4895 MemberExpr *MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue) {
4896     // FIXME: we need to call TypeCheck() here so that we can call
4897     // e->GetType() in the following.  But really we just shouldn't try to
4898     // resolve this now but just have a generic MemberExpr type that
4899     // handles all cases so that this is unnecessary.
4900     e = ::TypeCheck(e);
4901 
4902     const Type *exprType;
4903     if (e == NULL || (exprType = e->GetType()) == NULL)
4904         return NULL;
4905 
4906     const ReferenceType *referenceType = CastType<ReferenceType>(exprType);
4907     if (referenceType != NULL) {
4908         e = new RefDerefExpr(e, e->pos);
4909         exprType = e->GetType();
4910         Assert(exprType != NULL);
4911     }
4912 
4913     const PointerType *pointerType = CastType<PointerType>(exprType);
4914     if (pointerType != NULL)
4915         exprType = pointerType->GetBaseType();
4916 
4917     if (derefLValue == true && pointerType == NULL) {
4918         const Type *targetType = exprType->GetReferenceTarget();
4919         if (CastType<StructType>(targetType) != NULL)
4920             Error(p,
4921                   "Member operator \"->\" can't be applied to non-pointer "
4922                   "type \"%s\".  Did you mean to use \".\"?",
4923                   exprType->GetString().c_str());
4924         else
4925             Error(p,
4926                   "Member operator \"->\" can't be applied to non-struct "
4927                   "pointer type \"%s\".",
4928                   exprType->GetString().c_str());
4929         return NULL;
4930     }
4931     // For struct and short-vector, emit error if elements are accessed
4932     // incorrectly.
4933     if (derefLValue == false && pointerType != NULL &&
4934         ((CastType<StructType>(pointerType->GetBaseType()) != NULL) ||
4935          (CastType<VectorType>(pointerType->GetBaseType()) != NULL))) {
4936         Error(p,
4937               "Member operator \".\" can't be applied to pointer "
4938               "type \"%s\".  Did you mean to use \"->\"?",
4939               exprType->GetString().c_str());
4940         return NULL;
4941     }
4942     if (CastType<StructType>(exprType) != NULL) {
4943         const StructType *st = CastType<StructType>(exprType);
4944         if (st->IsDefined()) {
4945             return new StructMemberExpr(e, id, p, idpos, derefLValue);
4946         } else {
4947             Error(p,
4948                   "Member operator \"%s\" can't be applied to declared "
4949                   "struct \"%s\" containing an undefined struct type.",
4950                   derefLValue ? "->" : ".", exprType->GetString().c_str());
4951             return NULL;
4952         }
4953     } else if (CastType<VectorType>(exprType) != NULL)
4954         return new VectorMemberExpr(e, id, p, idpos, derefLValue);
4955     else if (CastType<UndefinedStructType>(exprType)) {
4956         Error(p,
4957               "Member operator \"%s\" can't be applied to declared "
4958               "but not defined struct type \"%s\".",
4959               derefLValue ? "->" : ".", exprType->GetString().c_str());
4960         return NULL;
4961     } else {
4962         Error(p,
4963               "Member operator \"%s\" can't be used with expression of "
4964               "\"%s\" type.",
4965               derefLValue ? "->" : ".", exprType->GetString().c_str());
4966         return NULL;
4967     }
4968 }
4969 
MemberExpr(Expr * e,const char * id,SourcePos p,SourcePos idpos,bool derefLValue,unsigned scid)4970 MemberExpr::MemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue, unsigned scid)
4971     : Expr(p, scid), identifierPos(idpos) {
4972     expr = e;
4973     identifier = id;
4974     dereferenceExpr = derefLValue;
4975     type = lvalueType = NULL;
4976 }
4977 
GetValue(FunctionEmitContext * ctx) const4978 llvm::Value *MemberExpr::GetValue(FunctionEmitContext *ctx) const {
4979     if (!expr)
4980         return NULL;
4981 
4982     llvm::Value *lvalue = GetLValue(ctx);
4983     const Type *lvalueType = GetLValueType();
4984 
4985     llvm::Value *mask = NULL;
4986     if (lvalue == NULL) {
4987         if (m->errorCount > 0)
4988             return NULL;
4989 
4990         // As in the array case, this may be a temporary that hasn't hit
4991         // memory; get the full value and stuff it into a temporary array
4992         // so that we can index from there...
4993         llvm::Value *val = expr->GetValue(ctx);
4994         if (!val) {
4995             AssertPos(pos, m->errorCount > 0);
4996             return NULL;
4997         }
4998         ctx->SetDebugPos(pos);
4999         const Type *exprType = expr->GetType();
5000         llvm::Value *ptr = ctx->AllocaInst(exprType, "struct_tmp");
5001         ctx->StoreInst(val, ptr, exprType, exprType->IsUniformType());
5002 
5003         int elementNumber = getElementNumber();
5004         if (elementNumber == -1)
5005             return NULL;
5006 
5007         lvalue = ctx->AddElementOffset(ptr, elementNumber, PointerType::GetUniform(exprType));
5008         lvalueType = PointerType::GetUniform(GetType());
5009         mask = LLVMMaskAllOn;
5010     } else {
5011         Symbol *baseSym = GetBaseSymbol();
5012         AssertPos(pos, baseSym != NULL);
5013         mask = lMaskForSymbol(baseSym, ctx);
5014     }
5015 
5016     ctx->SetDebugPos(pos);
5017     std::string suffix = std::string("_") + identifier;
5018     return ctx->LoadInst(lvalue, mask, lvalueType, llvm::Twine(lvalue->getName()) + suffix);
5019 }
5020 
GetType() const5021 const Type *MemberExpr::GetType() const { return NULL; }
5022 
GetBaseSymbol() const5023 Symbol *MemberExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
5024 
getElementNumber() const5025 int MemberExpr::getElementNumber() const { return -1; }
5026 
GetLValue(FunctionEmitContext * ctx) const5027 llvm::Value *MemberExpr::GetLValue(FunctionEmitContext *ctx) const {
5028     const Type *exprType;
5029     if (!expr || ((exprType = expr->GetType()) == NULL))
5030         return NULL;
5031 
5032     ctx->SetDebugPos(pos);
5033     llvm::Value *basePtr = dereferenceExpr ? expr->GetValue(ctx) : expr->GetLValue(ctx);
5034     if (!basePtr)
5035         return NULL;
5036 
5037     int elementNumber = getElementNumber();
5038     if (elementNumber == -1)
5039         return NULL;
5040 
5041     const Type *exprLValueType = dereferenceExpr ? exprType : expr->GetLValueType();
5042     ctx->SetDebugPos(pos);
5043     llvm::Value *ptr = ctx->AddElementOffset(basePtr, elementNumber, exprLValueType, basePtr->getName().str().c_str());
5044     if (ptr == NULL) {
5045         AssertPos(pos, m->errorCount > 0);
5046         return NULL;
5047     }
5048 
5049     ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetLValueType());
5050 
5051     return ptr;
5052 }
5053 
TypeCheck()5054 Expr *MemberExpr::TypeCheck() { return expr ? this : NULL; }
5055 
Optimize()5056 Expr *MemberExpr::Optimize() { return expr ? this : NULL; }
5057 
EstimateCost() const5058 int MemberExpr::EstimateCost() const {
5059     const Type *lvalueType = GetLValueType();
5060     if (lvalueType != NULL && lvalueType->IsVaryingType())
5061         return COST_GATHER + COST_SIMPLE_ARITH_LOGIC_OP;
5062     else
5063         return COST_SIMPLE_ARITH_LOGIC_OP;
5064 }
5065 
Print() const5066 void MemberExpr::Print() const {
5067     if (!expr || !GetType())
5068         return;
5069 
5070     printf("[%s] member (", GetType()->GetString().c_str());
5071     expr->Print();
5072     printf(" . %s)", identifier.c_str());
5073     pos.Print();
5074 }
5075 
5076 /** There is no structure member with the name we've got in "identifier".
5077     Use the approximate string matching routine to see if the identifier is
5078     a minor misspelling of one of the ones that is there.
5079  */
getCandidateNearMatches() const5080 std::string MemberExpr::getCandidateNearMatches() const {
5081     const StructType *structType = CastType<StructType>(expr->GetType());
5082     if (!structType)
5083         return "";
5084 
5085     std::vector<std::string> elementNames;
5086     for (int i = 0; i < structType->GetElementCount(); ++i)
5087         elementNames.push_back(structType->GetElementName(i));
5088     std::vector<std::string> alternates = MatchStrings(identifier, elementNames);
5089     if (!alternates.size())
5090         return "";
5091 
5092     std::string ret = " Did you mean ";
5093     for (unsigned int i = 0; i < alternates.size(); ++i) {
5094         ret += std::string("\"") + alternates[i] + std::string("\"");
5095         if (i < alternates.size() - 1)
5096             ret += ", or ";
5097     }
5098     ret += "?";
5099     return ret;
5100 }
5101 
5102 ///////////////////////////////////////////////////////////////////////////
5103 // ConstExpr
5104 
ConstExpr(const Type * t,int8_t i,SourcePos p)5105 ConstExpr::ConstExpr(const Type *t, int8_t i, SourcePos p) : Expr(p, ConstExprID) {
5106     type = t;
5107     type = type->GetAsConstType();
5108     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt8->GetAsConstType()));
5109     int8Val[0] = i;
5110 }
5111 
ConstExpr(const Type * t,int8_t * i,SourcePos p)5112 ConstExpr::ConstExpr(const Type *t, int8_t *i, SourcePos p) : Expr(p, ConstExprID) {
5113     type = t;
5114     type = type->GetAsConstType();
5115     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt8->GetAsConstType()) ||
5116                        Type::Equal(type, AtomicType::VaryingInt8->GetAsConstType()));
5117     for (int j = 0; j < Count(); ++j)
5118         int8Val[j] = i[j];
5119 }
5120 
ConstExpr(const Type * t,uint8_t u,SourcePos p)5121 ConstExpr::ConstExpr(const Type *t, uint8_t u, SourcePos p) : Expr(p, ConstExprID) {
5122     type = t;
5123     type = type->GetAsConstType();
5124     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt8->GetAsConstType()));
5125     uint8Val[0] = u;
5126 }
5127 
ConstExpr(const Type * t,uint8_t * u,SourcePos p)5128 ConstExpr::ConstExpr(const Type *t, uint8_t *u, SourcePos p) : Expr(p, ConstExprID) {
5129     type = t;
5130     type = type->GetAsConstType();
5131     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt8->GetAsConstType()) ||
5132                        Type::Equal(type, AtomicType::VaryingUInt8->GetAsConstType()));
5133     for (int j = 0; j < Count(); ++j)
5134         uint8Val[j] = u[j];
5135 }
5136 
ConstExpr(const Type * t,int16_t i,SourcePos p)5137 ConstExpr::ConstExpr(const Type *t, int16_t i, SourcePos p) : Expr(p, ConstExprID) {
5138     type = t;
5139     type = type->GetAsConstType();
5140     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt16->GetAsConstType()));
5141     int16Val[0] = i;
5142 }
5143 
ConstExpr(const Type * t,int16_t * i,SourcePos p)5144 ConstExpr::ConstExpr(const Type *t, int16_t *i, SourcePos p) : Expr(p, ConstExprID) {
5145     type = t;
5146     type = type->GetAsConstType();
5147     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt16->GetAsConstType()) ||
5148                        Type::Equal(type, AtomicType::VaryingInt16->GetAsConstType()));
5149     for (int j = 0; j < Count(); ++j)
5150         int16Val[j] = i[j];
5151 }
5152 
ConstExpr(const Type * t,uint16_t u,SourcePos p)5153 ConstExpr::ConstExpr(const Type *t, uint16_t u, SourcePos p) : Expr(p, ConstExprID) {
5154     type = t;
5155     type = type->GetAsConstType();
5156     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt16->GetAsConstType()));
5157     uint16Val[0] = u;
5158 }
5159 
ConstExpr(const Type * t,uint16_t * u,SourcePos p)5160 ConstExpr::ConstExpr(const Type *t, uint16_t *u, SourcePos p) : Expr(p, ConstExprID) {
5161     type = t;
5162     type = type->GetAsConstType();
5163     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt16->GetAsConstType()) ||
5164                        Type::Equal(type, AtomicType::VaryingUInt16->GetAsConstType()));
5165     for (int j = 0; j < Count(); ++j)
5166         uint16Val[j] = u[j];
5167 }
5168 
ConstExpr(const Type * t,int32_t i,SourcePos p)5169 ConstExpr::ConstExpr(const Type *t, int32_t i, SourcePos p) : Expr(p, ConstExprID) {
5170     type = t;
5171     type = type->GetAsConstType();
5172     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt32->GetAsConstType()));
5173     int32Val[0] = i;
5174 }
5175 
ConstExpr(const Type * t,int32_t * i,SourcePos p)5176 ConstExpr::ConstExpr(const Type *t, int32_t *i, SourcePos p) : Expr(p, ConstExprID) {
5177     type = t;
5178     type = type->GetAsConstType();
5179     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt32->GetAsConstType()) ||
5180                        Type::Equal(type, AtomicType::VaryingInt32->GetAsConstType()));
5181     for (int j = 0; j < Count(); ++j)
5182         int32Val[j] = i[j];
5183 }
5184 
ConstExpr(const Type * t,uint32_t u,SourcePos p)5185 ConstExpr::ConstExpr(const Type *t, uint32_t u, SourcePos p) : Expr(p, ConstExprID) {
5186     type = t;
5187     type = type->GetAsConstType();
5188     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt32->GetAsConstType()) ||
5189                        (CastType<EnumType>(type) != NULL && type->IsUniformType()));
5190     uint32Val[0] = u;
5191 }
5192 
ConstExpr(const Type * t,uint32_t * u,SourcePos p)5193 ConstExpr::ConstExpr(const Type *t, uint32_t *u, SourcePos p) : Expr(p, ConstExprID) {
5194     type = t;
5195     type = type->GetAsConstType();
5196     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt32->GetAsConstType()) ||
5197                        Type::Equal(type, AtomicType::VaryingUInt32->GetAsConstType()) ||
5198                        (CastType<EnumType>(type) != NULL));
5199     for (int j = 0; j < Count(); ++j)
5200         uint32Val[j] = u[j];
5201 }
5202 
ConstExpr(const Type * t,float f,SourcePos p)5203 ConstExpr::ConstExpr(const Type *t, float f, SourcePos p) : Expr(p, ConstExprID) {
5204     type = t;
5205     type = type->GetAsConstType();
5206     AssertPos(pos, Type::Equal(type, AtomicType::UniformFloat->GetAsConstType()));
5207     floatVal[0] = f;
5208 }
5209 
ConstExpr(const Type * t,float * f,SourcePos p)5210 ConstExpr::ConstExpr(const Type *t, float *f, SourcePos p) : Expr(p, ConstExprID) {
5211     type = t;
5212     type = type->GetAsConstType();
5213     AssertPos(pos, Type::Equal(type, AtomicType::UniformFloat->GetAsConstType()) ||
5214                        Type::Equal(type, AtomicType::VaryingFloat->GetAsConstType()));
5215     for (int j = 0; j < Count(); ++j)
5216         floatVal[j] = f[j];
5217 }
5218 
ConstExpr(const Type * t,int64_t i,SourcePos p)5219 ConstExpr::ConstExpr(const Type *t, int64_t i, SourcePos p) : Expr(p, ConstExprID) {
5220     type = t;
5221     type = type->GetAsConstType();
5222     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt64->GetAsConstType()));
5223     int64Val[0] = i;
5224 }
5225 
ConstExpr(const Type * t,int64_t * i,SourcePos p)5226 ConstExpr::ConstExpr(const Type *t, int64_t *i, SourcePos p) : Expr(p, ConstExprID) {
5227     type = t;
5228     type = type->GetAsConstType();
5229     AssertPos(pos, Type::Equal(type, AtomicType::UniformInt64->GetAsConstType()) ||
5230                        Type::Equal(type, AtomicType::VaryingInt64->GetAsConstType()));
5231     for (int j = 0; j < Count(); ++j)
5232         int64Val[j] = i[j];
5233 }
5234 
ConstExpr(const Type * t,uint64_t u,SourcePos p)5235 ConstExpr::ConstExpr(const Type *t, uint64_t u, SourcePos p) : Expr(p, ConstExprID) {
5236     type = t;
5237     type = type->GetAsConstType();
5238     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt64->GetAsConstType()));
5239     uint64Val[0] = u;
5240 }
5241 
ConstExpr(const Type * t,uint64_t * u,SourcePos p)5242 ConstExpr::ConstExpr(const Type *t, uint64_t *u, SourcePos p) : Expr(p, ConstExprID) {
5243     type = t;
5244     type = type->GetAsConstType();
5245     AssertPos(pos, Type::Equal(type, AtomicType::UniformUInt64->GetAsConstType()) ||
5246                        Type::Equal(type, AtomicType::VaryingUInt64->GetAsConstType()));
5247     for (int j = 0; j < Count(); ++j)
5248         uint64Val[j] = u[j];
5249 }
5250 
ConstExpr(const Type * t,double f,SourcePos p)5251 ConstExpr::ConstExpr(const Type *t, double f, SourcePos p) : Expr(p, ConstExprID) {
5252     type = t;
5253     type = type->GetAsConstType();
5254     AssertPos(pos, Type::Equal(type, AtomicType::UniformDouble->GetAsConstType()));
5255     doubleVal[0] = f;
5256 }
5257 
ConstExpr(const Type * t,double * f,SourcePos p)5258 ConstExpr::ConstExpr(const Type *t, double *f, SourcePos p) : Expr(p, ConstExprID) {
5259     type = t;
5260     type = type->GetAsConstType();
5261     AssertPos(pos, Type::Equal(type, AtomicType::UniformDouble->GetAsConstType()) ||
5262                        Type::Equal(type, AtomicType::VaryingDouble->GetAsConstType()));
5263     for (int j = 0; j < Count(); ++j)
5264         doubleVal[j] = f[j];
5265 }
5266 
ConstExpr(const Type * t,bool b,SourcePos p)5267 ConstExpr::ConstExpr(const Type *t, bool b, SourcePos p) : Expr(p, ConstExprID) {
5268     type = t;
5269     type = type->GetAsConstType();
5270     AssertPos(pos, Type::Equal(type, AtomicType::UniformBool->GetAsConstType()));
5271     boolVal[0] = b;
5272 }
5273 
ConstExpr(const Type * t,bool * b,SourcePos p)5274 ConstExpr::ConstExpr(const Type *t, bool *b, SourcePos p) : Expr(p, ConstExprID) {
5275     type = t;
5276     type = type->GetAsConstType();
5277     AssertPos(pos, Type::Equal(type, AtomicType::UniformBool->GetAsConstType()) ||
5278                        Type::Equal(type, AtomicType::VaryingBool->GetAsConstType()));
5279     for (int j = 0; j < Count(); ++j)
5280         boolVal[j] = b[j];
5281 }
5282 
ConstExpr(ConstExpr * old,double * v)5283 ConstExpr::ConstExpr(ConstExpr *old, double *v) : Expr(old->pos, ConstExprID) {
5284     type = old->type;
5285 
5286     AtomicType::BasicType basicType = getBasicType();
5287 
5288     switch (basicType) {
5289     case AtomicType::TYPE_BOOL:
5290         for (int i = 0; i < Count(); ++i)
5291             boolVal[i] = (v[i] != 0.);
5292         break;
5293     case AtomicType::TYPE_INT8:
5294         for (int i = 0; i < Count(); ++i)
5295             int8Val[i] = (int)v[i];
5296         break;
5297     case AtomicType::TYPE_UINT8:
5298         for (int i = 0; i < Count(); ++i)
5299             uint8Val[i] = (unsigned int)v[i];
5300         break;
5301     case AtomicType::TYPE_INT16:
5302         for (int i = 0; i < Count(); ++i)
5303             int16Val[i] = (int)v[i];
5304         break;
5305     case AtomicType::TYPE_UINT16:
5306         for (int i = 0; i < Count(); ++i)
5307             uint16Val[i] = (unsigned int)v[i];
5308         break;
5309     case AtomicType::TYPE_INT32:
5310         for (int i = 0; i < Count(); ++i)
5311             int32Val[i] = (int)v[i];
5312         break;
5313     case AtomicType::TYPE_UINT32:
5314         for (int i = 0; i < Count(); ++i)
5315             uint32Val[i] = (unsigned int)v[i];
5316         break;
5317     case AtomicType::TYPE_FLOAT:
5318         for (int i = 0; i < Count(); ++i)
5319             floatVal[i] = (float)v[i];
5320         break;
5321     case AtomicType::TYPE_DOUBLE:
5322         for (int i = 0; i < Count(); ++i)
5323             doubleVal[i] = v[i];
5324         break;
5325     case AtomicType::TYPE_INT64:
5326     case AtomicType::TYPE_UINT64:
5327         // For now, this should never be reached
5328         FATAL("fixme; we need another constructor so that we're not trying to pass "
5329               "double values to init an int64 type...");
5330     default:
5331         FATAL("unimplemented const type");
5332     }
5333 }
5334 
ConstExpr(ConstExpr * old,SourcePos p)5335 ConstExpr::ConstExpr(ConstExpr *old, SourcePos p) : Expr(p, ConstExprID) {
5336     type = old->type;
5337 
5338     AtomicType::BasicType basicType = getBasicType();
5339 
5340     switch (basicType) {
5341     case AtomicType::TYPE_BOOL:
5342         memcpy(boolVal, old->boolVal, Count() * sizeof(bool));
5343         break;
5344     case AtomicType::TYPE_INT8:
5345         memcpy(int8Val, old->int8Val, Count() * sizeof(int8_t));
5346         break;
5347     case AtomicType::TYPE_UINT8:
5348         memcpy(uint8Val, old->uint8Val, Count() * sizeof(uint8_t));
5349         break;
5350     case AtomicType::TYPE_INT16:
5351         memcpy(int16Val, old->int16Val, Count() * sizeof(int16_t));
5352         break;
5353     case AtomicType::TYPE_UINT16:
5354         memcpy(uint16Val, old->uint16Val, Count() * sizeof(uint16_t));
5355         break;
5356     case AtomicType::TYPE_INT32:
5357         memcpy(int32Val, old->int32Val, Count() * sizeof(int32_t));
5358         break;
5359     case AtomicType::TYPE_UINT32:
5360         memcpy(uint32Val, old->uint32Val, Count() * sizeof(uint32_t));
5361         break;
5362     case AtomicType::TYPE_FLOAT:
5363         memcpy(floatVal, old->floatVal, Count() * sizeof(float));
5364         break;
5365     case AtomicType::TYPE_DOUBLE:
5366         memcpy(doubleVal, old->doubleVal, Count() * sizeof(double));
5367         break;
5368     case AtomicType::TYPE_INT64:
5369         memcpy(int64Val, old->int64Val, Count() * sizeof(int64_t));
5370         break;
5371     case AtomicType::TYPE_UINT64:
5372         memcpy(uint64Val, old->uint64Val, Count() * sizeof(uint64_t));
5373         break;
5374     default:
5375         FATAL("unimplemented const type");
5376     }
5377 }
5378 
getBasicType() const5379 AtomicType::BasicType ConstExpr::getBasicType() const {
5380     const AtomicType *at = CastType<AtomicType>(type);
5381     if (at != NULL)
5382         return at->basicType;
5383     else {
5384         AssertPos(pos, CastType<EnumType>(type) != NULL);
5385         return AtomicType::TYPE_UINT32;
5386     }
5387 }
5388 
GetType() const5389 const Type *ConstExpr::GetType() const { return type; }
5390 
GetValue(FunctionEmitContext * ctx) const5391 llvm::Value *ConstExpr::GetValue(FunctionEmitContext *ctx) const {
5392     ctx->SetDebugPos(pos);
5393     bool isVarying = type->IsVaryingType();
5394 
5395     AtomicType::BasicType basicType = getBasicType();
5396 
5397     switch (basicType) {
5398     case AtomicType::TYPE_BOOL:
5399         if (isVarying)
5400             return LLVMBoolVector(boolVal);
5401         else
5402             return boolVal[0] ? LLVMTrue : LLVMFalse;
5403     case AtomicType::TYPE_INT8:
5404         return isVarying ? LLVMInt8Vector(int8Val) : LLVMInt8(int8Val[0]);
5405     case AtomicType::TYPE_UINT8:
5406         return isVarying ? LLVMUInt8Vector(uint8Val) : LLVMUInt8(uint8Val[0]);
5407     case AtomicType::TYPE_INT16:
5408         return isVarying ? LLVMInt16Vector(int16Val) : LLVMInt16(int16Val[0]);
5409     case AtomicType::TYPE_UINT16:
5410         return isVarying ? LLVMUInt16Vector(uint16Val) : LLVMUInt16(uint16Val[0]);
5411     case AtomicType::TYPE_INT32:
5412         return isVarying ? LLVMInt32Vector(int32Val) : LLVMInt32(int32Val[0]);
5413     case AtomicType::TYPE_UINT32:
5414         return isVarying ? LLVMUInt32Vector(uint32Val) : LLVMUInt32(uint32Val[0]);
5415     case AtomicType::TYPE_FLOAT:
5416         return isVarying ? LLVMFloatVector(floatVal) : LLVMFloat(floatVal[0]);
5417     case AtomicType::TYPE_INT64:
5418         return isVarying ? LLVMInt64Vector(int64Val) : LLVMInt64(int64Val[0]);
5419     case AtomicType::TYPE_UINT64:
5420         return isVarying ? LLVMUInt64Vector(uint64Val) : LLVMUInt64(uint64Val[0]);
5421     case AtomicType::TYPE_DOUBLE:
5422         return isVarying ? LLVMDoubleVector(doubleVal) : LLVMDouble(doubleVal[0]);
5423     default:
5424         FATAL("unimplemented const type");
5425         return NULL;
5426     }
5427 }
5428 
/* Type conversion templates: take advantage of C++ function overloading
   rules to get the one we want to match. */

/* The most general case, chosen when no more specific overload matches:
   converts a single value with an ordinary C++ conversion and stores it
   through "to". */
template <typename From, typename To> static inline void lConvertElement(From from, To *to) {
    *to = static_cast<To>(from);
}
5435 
/** Conversion from bool to a numeric type: the stored result is always
    exactly one (for true) or zero (for false).
 */
template <typename To> static inline void lConvertElement(bool from, To *to) {
    *to = static_cast<To>(from ? 1 : 0);
}
5440 
/** Conversion from a numeric type to bool: the result is true exactly when
    the value compares unequal to zero.  (Do we actually need this one??) */
template <typename From> static inline void lConvertElement(From from, bool *to) {
    *to = (from == 0) ? false : true;
}
5444 
/** bool to bool needs no conversion: the value is stored as is. */
static inline void lConvertElement(bool from, bool *to) {
    *to = from;
}
5447 
5448 /** Type conversion utility function
5449  */
lConvert(const From * from,To * to,int count,bool forceVarying)5450 template <typename From, typename To> static void lConvert(const From *from, To *to, int count, bool forceVarying) {
5451     for (int i = 0; i < count; ++i)
5452         lConvertElement(from[i], &to[i]);
5453 
5454     if (forceVarying && count == 1)
5455         for (int i = 1; i < g->target->getVectorWidth(); ++i)
5456             to[i] = to[0];
5457 }
5458 
GetValues(int64_t * ip,bool forceVarying) const5459 int ConstExpr::GetValues(int64_t *ip, bool forceVarying) const {
5460     switch (getBasicType()) {
5461     case AtomicType::TYPE_BOOL:
5462         lConvert(boolVal, ip, Count(), forceVarying);
5463         break;
5464     case AtomicType::TYPE_INT8:
5465         lConvert(int8Val, ip, Count(), forceVarying);
5466         break;
5467     case AtomicType::TYPE_UINT8:
5468         lConvert(uint8Val, ip, Count(), forceVarying);
5469         break;
5470     case AtomicType::TYPE_INT16:
5471         lConvert(int16Val, ip, Count(), forceVarying);
5472         break;
5473     case AtomicType::TYPE_UINT16:
5474         lConvert(uint16Val, ip, Count(), forceVarying);
5475         break;
5476     case AtomicType::TYPE_INT32:
5477         lConvert(int32Val, ip, Count(), forceVarying);
5478         break;
5479     case AtomicType::TYPE_UINT32:
5480         lConvert(uint32Val, ip, Count(), forceVarying);
5481         break;
5482     case AtomicType::TYPE_FLOAT:
5483         lConvert(floatVal, ip, Count(), forceVarying);
5484         break;
5485     case AtomicType::TYPE_DOUBLE:
5486         lConvert(doubleVal, ip, Count(), forceVarying);
5487         break;
5488     case AtomicType::TYPE_INT64:
5489         lConvert(int64Val, ip, Count(), forceVarying);
5490         break;
5491     case AtomicType::TYPE_UINT64:
5492         lConvert(uint64Val, ip, Count(), forceVarying);
5493         break;
5494     default:
5495         FATAL("unimplemented const type");
5496     }
5497     return Count();
5498 }
5499 
GetValues(uint64_t * up,bool forceVarying) const5500 int ConstExpr::GetValues(uint64_t *up, bool forceVarying) const {
5501     switch (getBasicType()) {
5502     case AtomicType::TYPE_BOOL:
5503         lConvert(boolVal, up, Count(), forceVarying);
5504         break;
5505     case AtomicType::TYPE_INT8:
5506         lConvert(int8Val, up, Count(), forceVarying);
5507         break;
5508     case AtomicType::TYPE_UINT8:
5509         lConvert(uint8Val, up, Count(), forceVarying);
5510         break;
5511     case AtomicType::TYPE_INT16:
5512         lConvert(int16Val, up, Count(), forceVarying);
5513         break;
5514     case AtomicType::TYPE_UINT16:
5515         lConvert(uint16Val, up, Count(), forceVarying);
5516         break;
5517     case AtomicType::TYPE_INT32:
5518         lConvert(int32Val, up, Count(), forceVarying);
5519         break;
5520     case AtomicType::TYPE_UINT32:
5521         lConvert(uint32Val, up, Count(), forceVarying);
5522         break;
5523     case AtomicType::TYPE_FLOAT:
5524         lConvert(floatVal, up, Count(), forceVarying);
5525         break;
5526     case AtomicType::TYPE_DOUBLE:
5527         lConvert(doubleVal, up, Count(), forceVarying);
5528         break;
5529     case AtomicType::TYPE_INT64:
5530         lConvert(int64Val, up, Count(), forceVarying);
5531         break;
5532     case AtomicType::TYPE_UINT64:
5533         lConvert(uint64Val, up, Count(), forceVarying);
5534         break;
5535     default:
5536         FATAL("unimplemented const type");
5537     }
5538     return Count();
5539 }
5540 
GetValues(double * d,bool forceVarying) const5541 int ConstExpr::GetValues(double *d, bool forceVarying) const {
5542     switch (getBasicType()) {
5543     case AtomicType::TYPE_BOOL:
5544         lConvert(boolVal, d, Count(), forceVarying);
5545         break;
5546     case AtomicType::TYPE_INT8:
5547         lConvert(int8Val, d, Count(), forceVarying);
5548         break;
5549     case AtomicType::TYPE_UINT8:
5550         lConvert(uint8Val, d, Count(), forceVarying);
5551         break;
5552     case AtomicType::TYPE_INT16:
5553         lConvert(int16Val, d, Count(), forceVarying);
5554         break;
5555     case AtomicType::TYPE_UINT16:
5556         lConvert(uint16Val, d, Count(), forceVarying);
5557         break;
5558     case AtomicType::TYPE_INT32:
5559         lConvert(int32Val, d, Count(), forceVarying);
5560         break;
5561     case AtomicType::TYPE_UINT32:
5562         lConvert(uint32Val, d, Count(), forceVarying);
5563         break;
5564     case AtomicType::TYPE_FLOAT:
5565         lConvert(floatVal, d, Count(), forceVarying);
5566         break;
5567     case AtomicType::TYPE_DOUBLE:
5568         lConvert(doubleVal, d, Count(), forceVarying);
5569         break;
5570     case AtomicType::TYPE_INT64:
5571         lConvert(int64Val, d, Count(), forceVarying);
5572         break;
5573     case AtomicType::TYPE_UINT64:
5574         lConvert(uint64Val, d, Count(), forceVarying);
5575         break;
5576     default:
5577         FATAL("unimplemented const type");
5578     }
5579     return Count();
5580 }
5581 
GetValues(float * fp,bool forceVarying) const5582 int ConstExpr::GetValues(float *fp, bool forceVarying) const {
5583     switch (getBasicType()) {
5584     case AtomicType::TYPE_BOOL:
5585         lConvert(boolVal, fp, Count(), forceVarying);
5586         break;
5587     case AtomicType::TYPE_INT8:
5588         lConvert(int8Val, fp, Count(), forceVarying);
5589         break;
5590     case AtomicType::TYPE_UINT8:
5591         lConvert(uint8Val, fp, Count(), forceVarying);
5592         break;
5593     case AtomicType::TYPE_INT16:
5594         lConvert(int16Val, fp, Count(), forceVarying);
5595         break;
5596     case AtomicType::TYPE_UINT16:
5597         lConvert(uint16Val, fp, Count(), forceVarying);
5598         break;
5599     case AtomicType::TYPE_INT32:
5600         lConvert(int32Val, fp, Count(), forceVarying);
5601         break;
5602     case AtomicType::TYPE_UINT32:
5603         lConvert(uint32Val, fp, Count(), forceVarying);
5604         break;
5605     case AtomicType::TYPE_FLOAT:
5606         lConvert(floatVal, fp, Count(), forceVarying);
5607         break;
5608     case AtomicType::TYPE_DOUBLE:
5609         lConvert(doubleVal, fp, Count(), forceVarying);
5610         break;
5611     case AtomicType::TYPE_INT64:
5612         lConvert(int64Val, fp, Count(), forceVarying);
5613         break;
5614     case AtomicType::TYPE_UINT64:
5615         lConvert(uint64Val, fp, Count(), forceVarying);
5616         break;
5617     default:
5618         FATAL("unimplemented const type");
5619     }
5620     return Count();
5621 }
5622 
GetValues(bool * b,bool forceVarying) const5623 int ConstExpr::GetValues(bool *b, bool forceVarying) const {
5624     switch (getBasicType()) {
5625     case AtomicType::TYPE_BOOL:
5626         lConvert(boolVal, b, Count(), forceVarying);
5627         break;
5628     case AtomicType::TYPE_INT8:
5629         lConvert(int8Val, b, Count(), forceVarying);
5630         break;
5631     case AtomicType::TYPE_UINT8:
5632         lConvert(uint8Val, b, Count(), forceVarying);
5633         break;
5634     case AtomicType::TYPE_INT16:
5635         lConvert(int16Val, b, Count(), forceVarying);
5636         break;
5637     case AtomicType::TYPE_UINT16:
5638         lConvert(uint16Val, b, Count(), forceVarying);
5639         break;
5640     case AtomicType::TYPE_INT32:
5641         lConvert(int32Val, b, Count(), forceVarying);
5642         break;
5643     case AtomicType::TYPE_UINT32:
5644         lConvert(uint32Val, b, Count(), forceVarying);
5645         break;
5646     case AtomicType::TYPE_FLOAT:
5647         lConvert(floatVal, b, Count(), forceVarying);
5648         break;
5649     case AtomicType::TYPE_DOUBLE:
5650         lConvert(doubleVal, b, Count(), forceVarying);
5651         break;
5652     case AtomicType::TYPE_INT64:
5653         lConvert(int64Val, b, Count(), forceVarying);
5654         break;
5655     case AtomicType::TYPE_UINT64:
5656         lConvert(uint64Val, b, Count(), forceVarying);
5657         break;
5658     default:
5659         FATAL("unimplemented const type");
5660     }
5661     return Count();
5662 }
5663 
GetValues(int8_t * ip,bool forceVarying) const5664 int ConstExpr::GetValues(int8_t *ip, bool forceVarying) const {
5665     switch (getBasicType()) {
5666     case AtomicType::TYPE_BOOL:
5667         lConvert(boolVal, ip, Count(), forceVarying);
5668         break;
5669     case AtomicType::TYPE_INT8:
5670         lConvert(int8Val, ip, Count(), forceVarying);
5671         break;
5672     case AtomicType::TYPE_UINT8:
5673         lConvert(uint8Val, ip, Count(), forceVarying);
5674         break;
5675     case AtomicType::TYPE_INT16:
5676         lConvert(int16Val, ip, Count(), forceVarying);
5677         break;
5678     case AtomicType::TYPE_UINT16:
5679         lConvert(uint16Val, ip, Count(), forceVarying);
5680         break;
5681     case AtomicType::TYPE_INT32:
5682         lConvert(int32Val, ip, Count(), forceVarying);
5683         break;
5684     case AtomicType::TYPE_UINT32:
5685         lConvert(uint32Val, ip, Count(), forceVarying);
5686         break;
5687     case AtomicType::TYPE_FLOAT:
5688         lConvert(floatVal, ip, Count(), forceVarying);
5689         break;
5690     case AtomicType::TYPE_DOUBLE:
5691         lConvert(doubleVal, ip, Count(), forceVarying);
5692         break;
5693     case AtomicType::TYPE_INT64:
5694         lConvert(int64Val, ip, Count(), forceVarying);
5695         break;
5696     case AtomicType::TYPE_UINT64:
5697         lConvert(uint64Val, ip, Count(), forceVarying);
5698         break;
5699     default:
5700         FATAL("unimplemented const type");
5701     }
5702     return Count();
5703 }
5704 
GetValues(uint8_t * up,bool forceVarying) const5705 int ConstExpr::GetValues(uint8_t *up, bool forceVarying) const {
5706     switch (getBasicType()) {
5707     case AtomicType::TYPE_BOOL:
5708         lConvert(boolVal, up, Count(), forceVarying);
5709         break;
5710     case AtomicType::TYPE_INT8:
5711         lConvert(int8Val, up, Count(), forceVarying);
5712         break;
5713     case AtomicType::TYPE_UINT8:
5714         lConvert(uint8Val, up, Count(), forceVarying);
5715         break;
5716     case AtomicType::TYPE_INT16:
5717         lConvert(int16Val, up, Count(), forceVarying);
5718         break;
5719     case AtomicType::TYPE_UINT16:
5720         lConvert(uint16Val, up, Count(), forceVarying);
5721         break;
5722     case AtomicType::TYPE_INT32:
5723         lConvert(int32Val, up, Count(), forceVarying);
5724         break;
5725     case AtomicType::TYPE_UINT32:
5726         lConvert(uint32Val, up, Count(), forceVarying);
5727         break;
5728     case AtomicType::TYPE_FLOAT:
5729         lConvert(floatVal, up, Count(), forceVarying);
5730         break;
5731     case AtomicType::TYPE_DOUBLE:
5732         lConvert(doubleVal, up, Count(), forceVarying);
5733         break;
5734     case AtomicType::TYPE_INT64:
5735         lConvert(int64Val, up, Count(), forceVarying);
5736         break;
5737     case AtomicType::TYPE_UINT64:
5738         lConvert(uint64Val, up, Count(), forceVarying);
5739         break;
5740     default:
5741         FATAL("unimplemented const type");
5742     }
5743     return Count();
5744 }
5745 
GetValues(int16_t * ip,bool forceVarying) const5746 int ConstExpr::GetValues(int16_t *ip, bool forceVarying) const {
5747     switch (getBasicType()) {
5748     case AtomicType::TYPE_BOOL:
5749         lConvert(boolVal, ip, Count(), forceVarying);
5750         break;
5751     case AtomicType::TYPE_INT8:
5752         lConvert(int8Val, ip, Count(), forceVarying);
5753         break;
5754     case AtomicType::TYPE_UINT8:
5755         lConvert(uint8Val, ip, Count(), forceVarying);
5756         break;
5757     case AtomicType::TYPE_INT16:
5758         lConvert(int16Val, ip, Count(), forceVarying);
5759         break;
5760     case AtomicType::TYPE_UINT16:
5761         lConvert(uint16Val, ip, Count(), forceVarying);
5762         break;
5763     case AtomicType::TYPE_INT32:
5764         lConvert(int32Val, ip, Count(), forceVarying);
5765         break;
5766     case AtomicType::TYPE_UINT32:
5767         lConvert(uint32Val, ip, Count(), forceVarying);
5768         break;
5769     case AtomicType::TYPE_FLOAT:
5770         lConvert(floatVal, ip, Count(), forceVarying);
5771         break;
5772     case AtomicType::TYPE_DOUBLE:
5773         lConvert(doubleVal, ip, Count(), forceVarying);
5774         break;
5775     case AtomicType::TYPE_INT64:
5776         lConvert(int64Val, ip, Count(), forceVarying);
5777         break;
5778     case AtomicType::TYPE_UINT64:
5779         lConvert(uint64Val, ip, Count(), forceVarying);
5780         break;
5781     default:
5782         FATAL("unimplemented const type");
5783     }
5784     return Count();
5785 }
5786 
GetValues(uint16_t * up,bool forceVarying) const5787 int ConstExpr::GetValues(uint16_t *up, bool forceVarying) const {
5788     switch (getBasicType()) {
5789     case AtomicType::TYPE_BOOL:
5790         lConvert(boolVal, up, Count(), forceVarying);
5791         break;
5792     case AtomicType::TYPE_INT8:
5793         lConvert(int8Val, up, Count(), forceVarying);
5794         break;
5795     case AtomicType::TYPE_UINT8:
5796         lConvert(uint8Val, up, Count(), forceVarying);
5797         break;
5798     case AtomicType::TYPE_INT16:
5799         lConvert(int16Val, up, Count(), forceVarying);
5800         break;
5801     case AtomicType::TYPE_UINT16:
5802         lConvert(uint16Val, up, Count(), forceVarying);
5803         break;
5804     case AtomicType::TYPE_INT32:
5805         lConvert(int32Val, up, Count(), forceVarying);
5806         break;
5807     case AtomicType::TYPE_UINT32:
5808         lConvert(uint32Val, up, Count(), forceVarying);
5809         break;
5810     case AtomicType::TYPE_FLOAT:
5811         lConvert(floatVal, up, Count(), forceVarying);
5812         break;
5813     case AtomicType::TYPE_DOUBLE:
5814         lConvert(doubleVal, up, Count(), forceVarying);
5815         break;
5816     case AtomicType::TYPE_INT64:
5817         lConvert(int64Val, up, Count(), forceVarying);
5818         break;
5819     case AtomicType::TYPE_UINT64:
5820         lConvert(uint64Val, up, Count(), forceVarying);
5821         break;
5822     default:
5823         FATAL("unimplemented const type");
5824     }
5825     return Count();
5826 }
5827 
GetValues(int32_t * ip,bool forceVarying) const5828 int ConstExpr::GetValues(int32_t *ip, bool forceVarying) const {
5829     switch (getBasicType()) {
5830     case AtomicType::TYPE_BOOL:
5831         lConvert(boolVal, ip, Count(), forceVarying);
5832         break;
5833     case AtomicType::TYPE_INT8:
5834         lConvert(int8Val, ip, Count(), forceVarying);
5835         break;
5836     case AtomicType::TYPE_UINT8:
5837         lConvert(uint8Val, ip, Count(), forceVarying);
5838         break;
5839     case AtomicType::TYPE_INT16:
5840         lConvert(int16Val, ip, Count(), forceVarying);
5841         break;
5842     case AtomicType::TYPE_UINT16:
5843         lConvert(uint16Val, ip, Count(), forceVarying);
5844         break;
5845     case AtomicType::TYPE_INT32:
5846         lConvert(int32Val, ip, Count(), forceVarying);
5847         break;
5848     case AtomicType::TYPE_UINT32:
5849         lConvert(uint32Val, ip, Count(), forceVarying);
5850         break;
5851     case AtomicType::TYPE_FLOAT:
5852         lConvert(floatVal, ip, Count(), forceVarying);
5853         break;
5854     case AtomicType::TYPE_DOUBLE:
5855         lConvert(doubleVal, ip, Count(), forceVarying);
5856         break;
5857     case AtomicType::TYPE_INT64:
5858         lConvert(int64Val, ip, Count(), forceVarying);
5859         break;
5860     case AtomicType::TYPE_UINT64:
5861         lConvert(uint64Val, ip, Count(), forceVarying);
5862         break;
5863     default:
5864         FATAL("unimplemented const type");
5865     }
5866     return Count();
5867 }
5868 
GetValues(uint32_t * up,bool forceVarying) const5869 int ConstExpr::GetValues(uint32_t *up, bool forceVarying) const {
5870     switch (getBasicType()) {
5871     case AtomicType::TYPE_BOOL:
5872         lConvert(boolVal, up, Count(), forceVarying);
5873         break;
5874     case AtomicType::TYPE_INT8:
5875         lConvert(int8Val, up, Count(), forceVarying);
5876         break;
5877     case AtomicType::TYPE_UINT8:
5878         lConvert(uint8Val, up, Count(), forceVarying);
5879         break;
5880     case AtomicType::TYPE_INT16:
5881         lConvert(int16Val, up, Count(), forceVarying);
5882         break;
5883     case AtomicType::TYPE_UINT16:
5884         lConvert(uint16Val, up, Count(), forceVarying);
5885         break;
5886     case AtomicType::TYPE_INT32:
5887         lConvert(int32Val, up, Count(), forceVarying);
5888         break;
5889     case AtomicType::TYPE_UINT32:
5890         lConvert(uint32Val, up, Count(), forceVarying);
5891         break;
5892     case AtomicType::TYPE_FLOAT:
5893         lConvert(floatVal, up, Count(), forceVarying);
5894         break;
5895     case AtomicType::TYPE_DOUBLE:
5896         lConvert(doubleVal, up, Count(), forceVarying);
5897         break;
5898     case AtomicType::TYPE_INT64:
5899         lConvert(int64Val, up, Count(), forceVarying);
5900         break;
5901     case AtomicType::TYPE_UINT64:
5902         lConvert(uint64Val, up, Count(), forceVarying);
5903         break;
5904     default:
5905         FATAL("unimplemented const type");
5906     }
5907     return Count();
5908 }
5909 
Count() const5910 int ConstExpr::Count() const { return GetType()->IsVaryingType() ? g->target->getVectorWidth() : 1; }
5911 
lGetConstExprConstant(const Type * constType,const ConstExpr * cExpr,bool isStorageType)5912 static std::pair<llvm::Constant *, bool> lGetConstExprConstant(const Type *constType, const ConstExpr *cExpr,
5913                                                                bool isStorageType) {
5914     // Caller shouldn't be trying to stuff a varying value here into a
5915     // constant type.
5916     SourcePos pos = cExpr->pos;
5917     bool isNotValidForMultiTargetGlobal = false;
5918     if (constType->IsUniformType())
5919         AssertPos(pos, cExpr->Count() == 1);
5920 
5921     constType = constType->GetAsNonConstType();
5922     if (Type::Equal(constType, AtomicType::UniformBool) || Type::Equal(constType, AtomicType::VaryingBool)) {
5923         bool bv[ISPC_MAX_NVEC];
5924         cExpr->GetValues(bv, constType->IsVaryingType());
5925         if (constType->IsUniformType()) {
5926             if (isStorageType)
5927                 return std::pair<llvm::Constant *, bool>(bv[0] ? LLVMTrueInStorage : LLVMFalseInStorage,
5928                                                          isNotValidForMultiTargetGlobal);
5929             else
5930                 return std::pair<llvm::Constant *, bool>(bv[0] ? LLVMTrue : LLVMFalse, isNotValidForMultiTargetGlobal);
5931         } else {
5932             if (isStorageType)
5933                 return std::pair<llvm::Constant *, bool>(LLVMBoolVectorInStorage(bv), isNotValidForMultiTargetGlobal);
5934             else
5935                 return std::pair<llvm::Constant *, bool>(LLVMBoolVector(bv), isNotValidForMultiTargetGlobal);
5936         }
5937     } else if (Type::Equal(constType, AtomicType::UniformInt8) || Type::Equal(constType, AtomicType::VaryingInt8)) {
5938         int8_t iv[ISPC_MAX_NVEC];
5939         cExpr->GetValues(iv, constType->IsVaryingType());
5940         if (constType->IsUniformType())
5941             return std::pair<llvm::Constant *, bool>(LLVMInt8(iv[0]), isNotValidForMultiTargetGlobal);
5942         else
5943             return std::pair<llvm::Constant *, bool>(LLVMInt8Vector(iv), isNotValidForMultiTargetGlobal);
5944     } else if (Type::Equal(constType, AtomicType::UniformUInt8) || Type::Equal(constType, AtomicType::VaryingUInt8)) {
5945         uint8_t uiv[ISPC_MAX_NVEC];
5946         cExpr->GetValues(uiv, constType->IsVaryingType());
5947         if (constType->IsUniformType())
5948             return std::pair<llvm::Constant *, bool>(LLVMUInt8(uiv[0]), isNotValidForMultiTargetGlobal);
5949         else
5950             return std::pair<llvm::Constant *, bool>(LLVMUInt8Vector(uiv), isNotValidForMultiTargetGlobal);
5951     } else if (Type::Equal(constType, AtomicType::UniformInt16) || Type::Equal(constType, AtomicType::VaryingInt16)) {
5952         int16_t iv[ISPC_MAX_NVEC];
5953         cExpr->GetValues(iv, constType->IsVaryingType());
5954         if (constType->IsUniformType())
5955             return std::pair<llvm::Constant *, bool>(LLVMInt16(iv[0]), isNotValidForMultiTargetGlobal);
5956         else
5957             return std::pair<llvm::Constant *, bool>(LLVMInt16Vector(iv), isNotValidForMultiTargetGlobal);
5958     } else if (Type::Equal(constType, AtomicType::UniformUInt16) || Type::Equal(constType, AtomicType::VaryingUInt16)) {
5959         uint16_t uiv[ISPC_MAX_NVEC];
5960         cExpr->GetValues(uiv, constType->IsVaryingType());
5961         if (constType->IsUniformType())
5962             return std::pair<llvm::Constant *, bool>(LLVMUInt16(uiv[0]), isNotValidForMultiTargetGlobal);
5963         else
5964             return std::pair<llvm::Constant *, bool>(LLVMUInt16Vector(uiv), isNotValidForMultiTargetGlobal);
5965     } else if (Type::Equal(constType, AtomicType::UniformInt32) || Type::Equal(constType, AtomicType::VaryingInt32)) {
5966         int32_t iv[ISPC_MAX_NVEC];
5967         cExpr->GetValues(iv, constType->IsVaryingType());
5968         if (constType->IsUniformType())
5969             return std::pair<llvm::Constant *, bool>(LLVMInt32(iv[0]), isNotValidForMultiTargetGlobal);
5970         else
5971             return std::pair<llvm::Constant *, bool>(LLVMInt32Vector(iv), isNotValidForMultiTargetGlobal);
5972     } else if (Type::Equal(constType, AtomicType::UniformUInt32) || Type::Equal(constType, AtomicType::VaryingUInt32) ||
5973                CastType<EnumType>(constType) != NULL) {
5974         uint32_t uiv[ISPC_MAX_NVEC];
5975         cExpr->GetValues(uiv, constType->IsVaryingType());
5976         if (constType->IsUniformType())
5977             return std::pair<llvm::Constant *, bool>(LLVMUInt32(uiv[0]), isNotValidForMultiTargetGlobal);
5978         else
5979             return std::pair<llvm::Constant *, bool>(LLVMUInt32Vector(uiv), isNotValidForMultiTargetGlobal);
5980     } else if (Type::Equal(constType, AtomicType::UniformFloat) || Type::Equal(constType, AtomicType::VaryingFloat)) {
5981         float fv[ISPC_MAX_NVEC];
5982         cExpr->GetValues(fv, constType->IsVaryingType());
5983         if (constType->IsUniformType())
5984             return std::pair<llvm::Constant *, bool>(LLVMFloat(fv[0]), isNotValidForMultiTargetGlobal);
5985         else
5986             return std::pair<llvm::Constant *, bool>(LLVMFloatVector(fv), isNotValidForMultiTargetGlobal);
5987     } else if (Type::Equal(constType, AtomicType::UniformInt64) || Type::Equal(constType, AtomicType::VaryingInt64)) {
5988         int64_t iv[ISPC_MAX_NVEC];
5989         cExpr->GetValues(iv, constType->IsVaryingType());
5990         if (constType->IsUniformType())
5991             return std::pair<llvm::Constant *, bool>(LLVMInt64(iv[0]), isNotValidForMultiTargetGlobal);
5992         else
5993             return std::pair<llvm::Constant *, bool>(LLVMInt64Vector(iv), isNotValidForMultiTargetGlobal);
5994     } else if (Type::Equal(constType, AtomicType::UniformUInt64) || Type::Equal(constType, AtomicType::VaryingUInt64)) {
5995         uint64_t uiv[ISPC_MAX_NVEC];
5996         cExpr->GetValues(uiv, constType->IsVaryingType());
5997         if (constType->IsUniformType())
5998             return std::pair<llvm::Constant *, bool>(LLVMUInt64(uiv[0]), isNotValidForMultiTargetGlobal);
5999         else
6000             return std::pair<llvm::Constant *, bool>(LLVMUInt64Vector(uiv), isNotValidForMultiTargetGlobal);
6001     } else if (Type::Equal(constType, AtomicType::UniformDouble) || Type::Equal(constType, AtomicType::VaryingDouble)) {
6002         double dv[ISPC_MAX_NVEC];
6003         cExpr->GetValues(dv, constType->IsVaryingType());
6004         if (constType->IsUniformType())
6005             return std::pair<llvm::Constant *, bool>(LLVMDouble(dv[0]), isNotValidForMultiTargetGlobal);
6006         else
6007             return std::pair<llvm::Constant *, bool>(LLVMDoubleVector(dv), isNotValidForMultiTargetGlobal);
6008     } else if (CastType<PointerType>(constType) != NULL) {
6009         // The only time we should get here is if we have an integer '0'
6010         // constant that should be turned into a NULL pointer of the
6011         // appropriate type.
6012         llvm::Type *llvmType = constType->LLVMType(g->ctx);
6013         if (llvmType == NULL) {
6014             AssertPos(pos, m->errorCount > 0);
6015             return std::pair<llvm::Constant *, bool>(NULL, false);
6016         }
6017 
6018         int64_t iv[ISPC_MAX_NVEC];
6019         cExpr->GetValues(iv, constType->IsVaryingType());
6020         for (int i = 0; i < cExpr->Count(); ++i)
6021             if (iv[i] != 0)
6022                 // We'll issue an error about this later--trying to assign
6023                 // a constant int to a pointer, without a typecast.
6024                 return std::pair<llvm::Constant *, bool>(NULL, false);
6025 
6026         return std::pair<llvm::Constant *, bool>(llvm::Constant::getNullValue(llvmType),
6027                                                  isNotValidForMultiTargetGlobal);
6028     } else {
6029         Debug(pos, "Unable to handle type \"%s\" in ConstExpr::GetConstant().", constType->GetString().c_str());
6030         return std::pair<llvm::Constant *, bool>(NULL, isNotValidForMultiTargetGlobal);
6031     }
6032 }
6033 
GetStorageConstant(const Type * constType) const6034 std::pair<llvm::Constant *, bool> ConstExpr::GetStorageConstant(const Type *constType) const {
6035     return lGetConstExprConstant(constType, this, true);
6036 }
GetConstant(const Type * constType) const6037 std::pair<llvm::Constant *, bool> ConstExpr::GetConstant(const Type *constType) const {
6038     return lGetConstExprConstant(constType, this, false);
6039 }
6040 
Optimize()6041 Expr *ConstExpr::Optimize() { return this; }
6042 
TypeCheck()6043 Expr *ConstExpr::TypeCheck() { return this; }
6044 
EstimateCost() const6045 int ConstExpr::EstimateCost() const { return 0; }
6046 
Print() const6047 void ConstExpr::Print() const {
6048     printf("[%s] (", GetType()->GetString().c_str());
6049     for (int i = 0; i < Count(); ++i) {
6050         switch (getBasicType()) {
6051         case AtomicType::TYPE_BOOL:
6052             printf("%s", boolVal[i] ? "true" : "false");
6053             break;
6054         case AtomicType::TYPE_INT8:
6055             printf("%d", (int)int8Val[i]);
6056             break;
6057         case AtomicType::TYPE_UINT8:
6058             printf("%u", (int)uint8Val[i]);
6059             break;
6060         case AtomicType::TYPE_INT16:
6061             printf("%d", (int)int16Val[i]);
6062             break;
6063         case AtomicType::TYPE_UINT16:
6064             printf("%u", (int)uint16Val[i]);
6065             break;
6066         case AtomicType::TYPE_INT32:
6067             printf("%d", int32Val[i]);
6068             break;
6069         case AtomicType::TYPE_UINT32:
6070             printf("%u", uint32Val[i]);
6071             break;
6072         case AtomicType::TYPE_FLOAT:
6073             printf("%f", floatVal[i]);
6074             break;
6075         case AtomicType::TYPE_INT64:
6076             printf("%" PRId64, int64Val[i]);
6077             break;
6078         case AtomicType::TYPE_UINT64:
6079             printf("%" PRIu64, uint64Val[i]);
6080             break;
6081         case AtomicType::TYPE_DOUBLE:
6082             printf("%f", doubleVal[i]);
6083             break;
6084         default:
6085             FATAL("unimplemented const type");
6086         }
6087         if (i != Count() - 1)
6088             printf(", ");
6089     }
6090     printf(")");
6091     pos.Print();
6092 }
6093 
6094 ///////////////////////////////////////////////////////////////////////////
6095 // TypeCastExpr
6096 
TypeCastExpr(const Type * t,Expr * e,SourcePos p)6097 TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p) : Expr(p, TypeCastExprID) {
6098     type = t;
6099     expr = e;
6100 }
6101 
6102 /** Handle all the grungy details of type conversion between atomic types.
6103     Given an input value in exprVal of type fromType, convert it to the
6104     llvm::Value with type toType.
6105  */
lTypeConvAtomic(FunctionEmitContext * ctx,llvm::Value * exprVal,const AtomicType * toType,const AtomicType * fromType,SourcePos pos)6106 static llvm::Value *lTypeConvAtomic(FunctionEmitContext *ctx, llvm::Value *exprVal, const AtomicType *toType,
6107                                     const AtomicType *fromType, SourcePos pos) {
6108     llvm::Value *cast = NULL;
6109 
6110     std::string opName = exprVal->getName().str();
6111     switch (toType->basicType) {
6112     case AtomicType::TYPE_BOOL:
6113         opName += "_to_bool";
6114         break;
6115     case AtomicType::TYPE_INT8:
6116         opName += "_to_int8";
6117         break;
6118     case AtomicType::TYPE_UINT8:
6119         opName += "_to_uint8";
6120         break;
6121     case AtomicType::TYPE_INT16:
6122         opName += "_to_int16";
6123         break;
6124     case AtomicType::TYPE_UINT16:
6125         opName += "_to_uint16";
6126         break;
6127     case AtomicType::TYPE_INT32:
6128         opName += "_to_int32";
6129         break;
6130     case AtomicType::TYPE_UINT32:
6131         opName += "_to_uint32";
6132         break;
6133     case AtomicType::TYPE_INT64:
6134         opName += "_to_int64";
6135         break;
6136     case AtomicType::TYPE_UINT64:
6137         opName += "_to_uint64";
6138         break;
6139     case AtomicType::TYPE_FLOAT:
6140         opName += "_to_float";
6141         break;
6142     case AtomicType::TYPE_DOUBLE:
6143         opName += "_to_double";
6144         break;
6145     default:
6146         FATAL("Unimplemented");
6147     }
6148     const char *cOpName = opName.c_str();
6149 
6150     switch (toType->basicType) {
6151     case AtomicType::TYPE_FLOAT: {
6152         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::FloatType : LLVMTypes::FloatVectorType;
6153         switch (fromType->basicType) {
6154         case AtomicType::TYPE_BOOL:
6155             if (fromType->IsVaryingType())
6156                 // If we have a bool vector of non-i1 elements, first
6157                 // truncate down to a single bit.
6158                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
            // And then do an unsigned int->float cast
6160             cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int
6161                                  exprVal, targetType, cOpName);
6162             break;
6163         case AtomicType::TYPE_INT8:
6164         case AtomicType::TYPE_INT16:
6165         case AtomicType::TYPE_INT32:
6166         case AtomicType::TYPE_INT64:
6167             cast = ctx->CastInst(llvm::Instruction::SIToFP, // signed int to float
6168                                  exprVal, targetType, cOpName);
6169             break;
6170         case AtomicType::TYPE_UINT8:
6171         case AtomicType::TYPE_UINT16:
6172         case AtomicType::TYPE_UINT32:
6173         case AtomicType::TYPE_UINT64:
6174             // float -> uint32 is the only conversion for which a signed cvt
6175             // exists which cannot be used for unsigned.
6176             // This is a problem for non-neon, non-avx512 targets from among
6177             // arm/x86 cpu targets. Revisit for genx/wasm.
6178             if (fromType->IsVaryingType() && (g->target->warnFtoU32IsExpensive() == true) &&
6179                 (fromType->basicType == AtomicType::TYPE_UINT32))
6180                 PerformanceWarning(pos, "Conversion from unsigned int to float is slow. "
6181                                         "Use \"int\" if possible");
6182             cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int to float
6183                                  exprVal, targetType, cOpName);
6184             break;
6185         case AtomicType::TYPE_FLOAT:
6186             // No-op cast.
6187             cast = exprVal;
6188             break;
6189         case AtomicType::TYPE_DOUBLE:
6190             cast = ctx->FPCastInst(exprVal, targetType, cOpName);
6191             break;
6192         default:
6193             FATAL("unimplemented");
6194         }
6195         break;
6196     }
6197     case AtomicType::TYPE_DOUBLE: {
6198         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::DoubleType : LLVMTypes::DoubleVectorType;
6199         switch (fromType->basicType) {
6200         case AtomicType::TYPE_BOOL:
6201             if (fromType->IsVaryingType())
6202                 // truncate bool vector values to i1s if necessary.
6203                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6204             cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int to double
6205                                  exprVal, targetType, cOpName);
6206             break;
6207         case AtomicType::TYPE_INT8:
6208         case AtomicType::TYPE_INT16:
6209         case AtomicType::TYPE_INT32:
6210         case AtomicType::TYPE_INT64:
6211             cast = ctx->CastInst(llvm::Instruction::SIToFP, // signed int
6212                                  exprVal, targetType, cOpName);
6213             break;
6214         case AtomicType::TYPE_UINT8:
6215         case AtomicType::TYPE_UINT16:
6216         case AtomicType::TYPE_UINT32:
6217         case AtomicType::TYPE_UINT64:
6218             cast = ctx->CastInst(llvm::Instruction::UIToFP, // unsigned int
6219                                  exprVal, targetType, cOpName);
6220             break;
6221         case AtomicType::TYPE_FLOAT:
6222             cast = ctx->FPCastInst(exprVal, targetType, cOpName);
6223             break;
6224         case AtomicType::TYPE_DOUBLE:
6225             cast = exprVal;
6226             break;
6227         default:
6228             FATAL("unimplemented");
6229         }
6230         break;
6231     }
6232     case AtomicType::TYPE_INT8: {
6233         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int8Type : LLVMTypes::Int8VectorType;
6234         switch (fromType->basicType) {
6235         case AtomicType::TYPE_BOOL:
6236             if (fromType->IsVaryingType())
6237                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6238             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6239             break;
6240         case AtomicType::TYPE_INT8:
6241         case AtomicType::TYPE_UINT8:
6242             cast = exprVal;
6243             break;
6244         case AtomicType::TYPE_INT16:
6245         case AtomicType::TYPE_UINT16:
6246         case AtomicType::TYPE_INT32:
6247         case AtomicType::TYPE_UINT32:
6248         case AtomicType::TYPE_INT64:
6249         case AtomicType::TYPE_UINT64:
6250             cast = ctx->TruncInst(exprVal, targetType, cOpName);
6251             break;
6252         case AtomicType::TYPE_FLOAT:
6253         case AtomicType::TYPE_DOUBLE:
6254             cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
6255                                  exprVal, targetType, cOpName);
6256             break;
6257         default:
6258             FATAL("unimplemented");
6259         }
6260         break;
6261     }
6262     case AtomicType::TYPE_UINT8: {
6263         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int8Type : LLVMTypes::Int8VectorType;
6264         switch (fromType->basicType) {
6265         case AtomicType::TYPE_BOOL:
6266             if (fromType->IsVaryingType())
6267                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6268             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6269             break;
6270         case AtomicType::TYPE_INT8:
6271         case AtomicType::TYPE_UINT8:
6272             cast = exprVal;
6273             break;
6274         case AtomicType::TYPE_INT16:
6275         case AtomicType::TYPE_UINT16:
6276         case AtomicType::TYPE_INT32:
6277         case AtomicType::TYPE_UINT32:
6278         case AtomicType::TYPE_INT64:
6279         case AtomicType::TYPE_UINT64:
6280             cast = ctx->TruncInst(exprVal, targetType, cOpName);
6281             break;
6282         case AtomicType::TYPE_FLOAT:
6283             if (fromType->IsVaryingType())
6284                 PerformanceWarning(pos, "Conversion from float to unsigned int is slow. "
6285                                         "Use \"int\" if possible");
6286             cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
6287                                  exprVal, targetType, cOpName);
6288             break;
6289         case AtomicType::TYPE_DOUBLE:
6290             if (fromType->IsVaryingType())
6291                 PerformanceWarning(pos, "Conversion from double to unsigned int is slow. "
6292                                         "Use \"int\" if possible");
6293             cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
6294                                  exprVal, targetType, cOpName);
6295             break;
6296         default:
6297             FATAL("unimplemented");
6298         }
6299         break;
6300     }
6301     case AtomicType::TYPE_INT16: {
6302         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int16Type : LLVMTypes::Int16VectorType;
6303         switch (fromType->basicType) {
6304         case AtomicType::TYPE_BOOL:
6305             if (fromType->IsVaryingType())
6306                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6307             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6308             break;
6309         case AtomicType::TYPE_INT8:
6310             cast = ctx->SExtInst(exprVal, targetType, cOpName);
6311             break;
6312         case AtomicType::TYPE_UINT8:
6313             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6314             break;
6315         case AtomicType::TYPE_INT16:
6316         case AtomicType::TYPE_UINT16:
6317             cast = exprVal;
6318             break;
6319         case AtomicType::TYPE_INT32:
6320         case AtomicType::TYPE_UINT32:
6321         case AtomicType::TYPE_INT64:
6322         case AtomicType::TYPE_UINT64:
6323             cast = ctx->TruncInst(exprVal, targetType, cOpName);
6324             break;
6325         case AtomicType::TYPE_FLOAT:
6326         case AtomicType::TYPE_DOUBLE:
6327             cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
6328                                  exprVal, targetType, cOpName);
6329             break;
6330         default:
6331             FATAL("unimplemented");
6332         }
6333         break;
6334     }
6335     case AtomicType::TYPE_UINT16: {
6336         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int16Type : LLVMTypes::Int16VectorType;
6337         switch (fromType->basicType) {
6338         case AtomicType::TYPE_BOOL:
6339             if (fromType->IsVaryingType())
6340                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6341             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6342             break;
6343         case AtomicType::TYPE_INT8:
6344             cast = ctx->SExtInst(exprVal, targetType, cOpName);
6345             break;
6346         case AtomicType::TYPE_UINT8:
6347             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6348             break;
6349         case AtomicType::TYPE_INT16:
6350         case AtomicType::TYPE_UINT16:
6351             cast = exprVal;
6352             break;
6353         case AtomicType::TYPE_FLOAT:
6354             if (fromType->IsVaryingType())
6355                 PerformanceWarning(pos, "Conversion from float to unsigned int is slow. "
6356                                         "Use \"int\" if possible");
6357             cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
6358                                  exprVal, targetType, cOpName);
6359             break;
6360         case AtomicType::TYPE_INT32:
6361         case AtomicType::TYPE_UINT32:
6362         case AtomicType::TYPE_INT64:
6363         case AtomicType::TYPE_UINT64:
6364             cast = ctx->TruncInst(exprVal, targetType, cOpName);
6365             break;
6366         case AtomicType::TYPE_DOUBLE:
6367             if (fromType->IsVaryingType())
6368                 PerformanceWarning(pos, "Conversion from double to unsigned int is slow. "
6369                                         "Use \"int\" if possible");
6370             cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
6371                                  exprVal, targetType, cOpName);
6372             break;
6373         default:
6374             FATAL("unimplemented");
6375         }
6376         break;
6377     }
6378     case AtomicType::TYPE_INT32: {
6379         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int32Type : LLVMTypes::Int32VectorType;
6380         switch (fromType->basicType) {
6381         case AtomicType::TYPE_BOOL:
6382             if (fromType->IsVaryingType())
6383                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6384             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6385             break;
6386         case AtomicType::TYPE_INT8:
6387         case AtomicType::TYPE_INT16:
6388             cast = ctx->SExtInst(exprVal, targetType, cOpName);
6389             break;
6390         case AtomicType::TYPE_UINT8:
6391         case AtomicType::TYPE_UINT16:
6392             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6393             break;
6394         case AtomicType::TYPE_INT32:
6395         case AtomicType::TYPE_UINT32:
6396             cast = exprVal;
6397             break;
6398         case AtomicType::TYPE_INT64:
6399         case AtomicType::TYPE_UINT64:
6400             cast = ctx->TruncInst(exprVal, targetType, cOpName);
6401             break;
6402         case AtomicType::TYPE_FLOAT:
6403         case AtomicType::TYPE_DOUBLE:
6404             cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
6405                                  exprVal, targetType, cOpName);
6406             break;
6407         default:
6408             FATAL("unimplemented");
6409         }
6410         break;
6411     }
6412     case AtomicType::TYPE_UINT32: {
6413         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int32Type : LLVMTypes::Int32VectorType;
6414         switch (fromType->basicType) {
6415         case AtomicType::TYPE_BOOL:
6416             if (fromType->IsVaryingType())
6417                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6418             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6419             break;
6420         case AtomicType::TYPE_INT8:
6421         case AtomicType::TYPE_INT16:
6422             cast = ctx->SExtInst(exprVal, targetType, cOpName);
6423             break;
6424         case AtomicType::TYPE_UINT8:
6425         case AtomicType::TYPE_UINT16:
6426             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6427             break;
6428         case AtomicType::TYPE_INT32:
6429         case AtomicType::TYPE_UINT32:
6430             cast = exprVal;
6431             break;
6432         case AtomicType::TYPE_FLOAT:
6433             if (fromType->IsVaryingType())
6434                 PerformanceWarning(pos, "Conversion from float to unsigned int is slow. "
6435                                         "Use \"int\" if possible");
6436             cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
6437                                  exprVal, targetType, cOpName);
6438             break;
6439         case AtomicType::TYPE_INT64:
6440         case AtomicType::TYPE_UINT64:
6441             cast = ctx->TruncInst(exprVal, targetType, cOpName);
6442             break;
6443         case AtomicType::TYPE_DOUBLE:
6444             if (fromType->IsVaryingType())
6445                 PerformanceWarning(pos, "Conversion from double to unsigned int is slow. "
6446                                         "Use \"int\" if possible");
6447             cast = ctx->CastInst(llvm::Instruction::FPToUI, // unsigned int
6448                                  exprVal, targetType, cOpName);
6449             break;
6450         default:
6451             FATAL("unimplemented");
6452         }
6453         break;
6454     }
6455     case AtomicType::TYPE_INT64: {
6456         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int64Type : LLVMTypes::Int64VectorType;
6457         switch (fromType->basicType) {
6458         case AtomicType::TYPE_BOOL:
6459             if (fromType->IsVaryingType())
6460                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6461             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6462             break;
6463         case AtomicType::TYPE_INT8:
6464         case AtomicType::TYPE_INT16:
6465         case AtomicType::TYPE_INT32:
6466             cast = ctx->SExtInst(exprVal, targetType, cOpName);
6467             break;
6468         case AtomicType::TYPE_UINT8:
6469         case AtomicType::TYPE_UINT16:
6470         case AtomicType::TYPE_UINT32:
6471             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6472             break;
6473         case AtomicType::TYPE_INT64:
6474         case AtomicType::TYPE_UINT64:
6475             cast = exprVal;
6476             break;
6477         case AtomicType::TYPE_FLOAT:
6478         case AtomicType::TYPE_DOUBLE:
6479             cast = ctx->CastInst(llvm::Instruction::FPToSI, // signed int
6480                                  exprVal, targetType, cOpName);
6481             break;
6482         default:
6483             FATAL("unimplemented");
6484         }
6485         break;
6486     }
6487     case AtomicType::TYPE_UINT64: {
6488         llvm::Type *targetType = fromType->IsUniformType() ? LLVMTypes::Int64Type : LLVMTypes::Int64VectorType;
6489         switch (fromType->basicType) {
6490         case AtomicType::TYPE_BOOL:
6491             if (fromType->IsVaryingType())
6492                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6493             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6494             break;
6495         case AtomicType::TYPE_INT8:
6496         case AtomicType::TYPE_INT16:
6497         case AtomicType::TYPE_INT32:
6498             cast = ctx->SExtInst(exprVal, targetType, cOpName);
6499             break;
6500         case AtomicType::TYPE_UINT8:
6501         case AtomicType::TYPE_UINT16:
6502         case AtomicType::TYPE_UINT32:
6503             cast = ctx->ZExtInst(exprVal, targetType, cOpName);
6504             break;
6505         case AtomicType::TYPE_FLOAT:
6506             if (fromType->IsVaryingType())
6507                 PerformanceWarning(pos, "Conversion from float to unsigned int64 is slow. "
6508                                         "Use \"int64\" if possible");
6509             cast = ctx->CastInst(llvm::Instruction::FPToUI, // signed int
6510                                  exprVal, targetType, cOpName);
6511             break;
6512         case AtomicType::TYPE_INT64:
6513         case AtomicType::TYPE_UINT64:
6514             cast = exprVal;
6515             break;
6516         case AtomicType::TYPE_DOUBLE:
6517             if (fromType->IsVaryingType())
6518                 PerformanceWarning(pos, "Conversion from double to unsigned int64 is slow. "
6519                                         "Use \"int64\" if possible");
6520             cast = ctx->CastInst(llvm::Instruction::FPToUI, // signed int
6521                                  exprVal, targetType, cOpName);
6522             break;
6523         default:
6524             FATAL("unimplemented");
6525         }
6526         break;
6527     }
6528     case AtomicType::TYPE_BOOL: {
6529         switch (fromType->basicType) {
6530         case AtomicType::TYPE_BOOL:
6531             if (fromType->IsVaryingType()) {
6532                 // truncate bool vector values to i1s if necessary.
6533                 exprVal = ctx->SwitchBoolSize(exprVal, LLVMTypes::Int1VectorType, cOpName);
6534             }
6535             cast = exprVal;
6536             break;
6537         case AtomicType::TYPE_INT8:
6538         case AtomicType::TYPE_UINT8: {
6539             llvm::Value *zero =
6540                 fromType->IsUniformType() ? (llvm::Value *)LLVMInt8(0) : (llvm::Value *)LLVMInt8Vector((int8_t)0);
6541             cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
6542             break;
6543         }
6544         case AtomicType::TYPE_INT16:
6545         case AtomicType::TYPE_UINT16: {
6546             llvm::Value *zero =
6547                 fromType->IsUniformType() ? (llvm::Value *)LLVMInt16(0) : (llvm::Value *)LLVMInt16Vector((int16_t)0);
6548             cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
6549             break;
6550         }
6551         case AtomicType::TYPE_INT32:
6552         case AtomicType::TYPE_UINT32: {
6553             llvm::Value *zero =
6554                 fromType->IsUniformType() ? (llvm::Value *)LLVMInt32(0) : (llvm::Value *)LLVMInt32Vector(0);
6555             cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
6556             break;
6557         }
6558         case AtomicType::TYPE_FLOAT: {
6559             llvm::Value *zero =
6560                 fromType->IsUniformType() ? (llvm::Value *)LLVMFloat(0.f) : (llvm::Value *)LLVMFloatVector(0.f);
6561             cast = ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_ONE, exprVal, zero, cOpName);
6562             break;
6563         }
6564         case AtomicType::TYPE_INT64:
6565         case AtomicType::TYPE_UINT64: {
6566             llvm::Value *zero =
6567                 fromType->IsUniformType() ? (llvm::Value *)LLVMInt64(0) : (llvm::Value *)LLVMInt64Vector((int64_t)0);
6568             cast = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, zero, cOpName);
6569             break;
6570         }
6571         case AtomicType::TYPE_DOUBLE: {
6572             llvm::Value *zero =
6573                 fromType->IsUniformType() ? (llvm::Value *)LLVMDouble(0.) : (llvm::Value *)LLVMDoubleVector(0.);
6574             cast = ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_ONE, exprVal, zero, cOpName);
6575             break;
6576         }
6577         default:
6578             FATAL("unimplemented");
6579         }
6580 
6581         if (fromType->IsUniformType()) {
6582             if (toType->IsVaryingType() && LLVMTypes::BoolVectorType != LLVMTypes::Int1VectorType) {
6583                 // extend out to an bool as an i8/i16/i32 from the i1 here.
6584                 // Then we'll turn that into a vector below, the way it
6585                 // does for everyone else...
6586                 Assert(cast);
6587                 cast = ctx->SwitchBoolSize(cast, LLVMTypes::BoolVectorType->getElementType(),
6588                                            llvm::Twine(cast->getName()) + "to_i_bool");
6589             }
6590         } else {
6591             // fromType->IsVaryingType())
6592             cast = ctx->I1VecToBoolVec(cast);
6593         }
6594         break;
6595     }
6596     default:
6597         FATAL("unimplemented");
6598     }
6599 
6600     // If we also want to go from uniform to varying, replicate out the
6601     // value across the vector elements..
6602     if (toType->IsVaryingType() && fromType->IsUniformType())
6603         return ctx->SmearUniform(cast);
6604     else
6605         return cast;
6606 }
6607 
6608 // FIXME: fold this into the FunctionEmitContext::SmearUniform() method?
6609 
6610 /** Converts the given value of the given type to be the varying
6611     equivalent, returning the resulting value.
6612  */
lUniformValueToVarying(FunctionEmitContext * ctx,llvm::Value * value,const Type * type,SourcePos pos)6613 static llvm::Value *lUniformValueToVarying(FunctionEmitContext *ctx, llvm::Value *value, const Type *type,
6614                                            SourcePos pos) {
6615     // nothing to do if it's already varying
6616     if (type->IsVaryingType())
6617         return value;
6618 
6619     // for structs/arrays/vectors, just recursively make their elements
6620     // varying (if needed) and populate the return value.
6621     const CollectionType *collectionType = CastType<CollectionType>(type);
6622     if (collectionType != NULL) {
6623         llvm::Type *llvmType = type->GetAsVaryingType()->LLVMStorageType(g->ctx);
6624         llvm::Value *retValue = llvm::UndefValue::get(llvmType);
6625 
6626         const StructType *structType = CastType<StructType>(type->GetAsVaryingType());
6627 
6628         for (int i = 0; i < collectionType->GetElementCount(); ++i) {
6629             llvm::Value *v = ctx->ExtractInst(value, i, "get_element");
6630             // If struct has "bound uniform" member, we don't need to cast it to varying
6631             if (!(structType != NULL && structType->GetElementType(i)->IsUniformType())) {
6632                 const Type *elemType = collectionType->GetElementType(i);
6633                 // If member is a uniform bool, it needs to be truncated to i1 since
6634                 // uniform  bool in IR is i1 and i8 in struct
6635                 // Consider switching to just a broadcast for bool
6636                 if ((elemType->IsBoolType()) && (CastType<AtomicType>(elemType) != NULL)) {
6637                     v = ctx->TruncInst(v, LLVMTypes::BoolType);
6638                 }
6639                 v = lUniformValueToVarying(ctx, v, elemType, pos);
6640                 // If the extracted element if bool and varying needs to be
6641                 // converted back to i8 vector to insert into varying struct.
6642                 if ((elemType->IsBoolType()) && (CastType<AtomicType>(elemType) != NULL)) {
6643                     v = ctx->SwitchBoolSize(v, LLVMTypes::BoolVectorStorageType);
6644                 }
6645             }
6646             retValue = ctx->InsertInst(retValue, v, i, "set_element");
6647         }
6648         return retValue;
6649     }
6650 
6651     // Otherwise we must have a uniform atomic or pointer type, so smear
6652     // its value across the vector lanes.
6653     if (CastType<AtomicType>(type)) {
6654         return lTypeConvAtomic(ctx, value, CastType<AtomicType>(type->GetAsVaryingType()), CastType<AtomicType>(type),
6655                                pos);
6656     }
6657 
6658     Assert(CastType<PointerType>(type) != NULL);
6659     return ctx->SmearUniform(value);
6660 }
6661 
HasAmbiguousVariability(std::vector<const Expr * > & warn) const6662 bool TypeCastExpr::HasAmbiguousVariability(std::vector<const Expr *> &warn) const {
6663 
6664     if (expr == NULL)
6665         return false;
6666 
6667     const Type *toType = type, *fromType = expr->GetType();
6668     if (toType == NULL || fromType == NULL)
6669         return false;
6670 
6671     if (toType->HasUnboundVariability() && fromType->IsUniformType()) {
6672         warn.push_back(this);
6673         return true;
6674     }
6675 
6676     return false;
6677 }
6678 
PrintAmbiguousVariability() const6679 void TypeCastExpr::PrintAmbiguousVariability() const {
6680     Warning(pos,
6681             "Typecasting to type \"%s\" (variability not specified) "
6682             "from \"uniform\" type \"%s\" results in \"uniform\" variability.\n"
6683             "In the context of function argument it may lead to unexpected behavior. "
6684             "Casting to \"%s\" is recommended.",
6685             (type->GetString()).c_str(), ((expr->GetType())->GetString()).c_str(),
6686             (type->GetAsUniformType()->GetString()).c_str());
6687 }
6688 
GetValue(FunctionEmitContext * ctx) const6689 llvm::Value *TypeCastExpr::GetValue(FunctionEmitContext *ctx) const {
6690     if (!expr)
6691         return NULL;
6692 
6693     ctx->SetDebugPos(pos);
6694     const Type *toType = GetType(), *fromType = expr->GetType();
6695     if (toType == NULL || fromType == NULL) {
6696         AssertPos(pos, m->errorCount > 0);
6697         return NULL;
6698     }
6699 
6700     if (toType->IsVoidType()) {
6701         // emit the code for the expression in case it has side-effects but
6702         // then we're done.
6703         (void)expr->GetValue(ctx);
6704         return NULL;
6705     }
6706 
6707     const PointerType *fromPointerType = CastType<PointerType>(fromType);
6708     const PointerType *toPointerType = CastType<PointerType>(toType);
6709     const ArrayType *toArrayType = CastType<ArrayType>(toType);
6710     const ArrayType *fromArrayType = CastType<ArrayType>(fromType);
6711     if (fromPointerType != NULL) {
6712         if (toArrayType != NULL) {
6713             return expr->GetValue(ctx);
6714         } else if (toPointerType != NULL) {
6715             llvm::Value *value = expr->GetValue(ctx);
6716             if (value == NULL)
6717                 return NULL;
6718 
6719             if (fromPointerType->IsSlice() == false && toPointerType->IsSlice() == true) {
6720                 // Convert from a non-slice pointer to a slice pointer by
6721                 // creating a slice pointer structure with zero offsets.
6722                 if (fromPointerType->IsUniformType())
6723                     value = ctx->MakeSlicePointer(value, LLVMInt32(0));
6724                 else
6725                     value = ctx->MakeSlicePointer(value, LLVMInt32Vector(0));
6726 
6727                 // FIXME: avoid error from unnecessary bitcast when all we
6728                 // need to do is the slice conversion and don't need to
6729                 // also do unif->varying conversions.  But this is really
6730                 // ugly logic.
6731                 if (value->getType() == toType->LLVMType(g->ctx))
6732                     return value;
6733             }
6734 
6735             if (fromType->IsUniformType() && toType->IsUniformType())
6736                 // bitcast to the actual pointer type
6737                 return ctx->BitCastInst(value, toType->LLVMType(g->ctx));
6738             else if (fromType->IsVaryingType() && toType->IsVaryingType()) {
6739                 // both are vectors of ints already, nothing to do at the IR
6740                 // level
6741                 return value;
6742             } else {
6743                 // Uniform -> varying pointer conversion
6744                 AssertPos(pos, fromType->IsUniformType() && toType->IsVaryingType());
6745                 if (fromPointerType->IsSlice()) {
6746                     // For slice pointers, we need to smear out both the
6747                     // pointer and the offset vector
6748                     AssertPos(pos, toPointerType->IsSlice());
6749                     llvm::Value *ptr = ctx->ExtractInst(value, 0);
6750                     llvm::Value *offset = ctx->ExtractInst(value, 1);
6751                     ptr = ctx->PtrToIntInst(ptr);
6752                     ptr = ctx->SmearUniform(ptr);
6753                     offset = ctx->SmearUniform(offset);
6754                     return ctx->MakeSlicePointer(ptr, offset);
6755                 } else {
6756                     // Otherwise we just bitcast it to an int and smear it
6757                     // out to a vector
6758                     value = ctx->PtrToIntInst(value);
6759                     return ctx->SmearUniform(value);
6760                 }
6761             }
6762         } else {
6763             AssertPos(pos, CastType<AtomicType>(toType) != NULL);
6764             if (toType->IsBoolType()) {
6765                 // convert pointer to bool
6766                 llvm::Type *lfu = fromType->GetAsUniformType()->LLVMType(g->ctx);
6767                 llvm::PointerType *llvmFromUnifType = llvm::dyn_cast<llvm::PointerType>(lfu);
6768 
6769                 llvm::Value *nullPtrValue = llvm::ConstantPointerNull::get(llvmFromUnifType);
6770                 if (fromType->IsVaryingType())
6771                     nullPtrValue = ctx->SmearUniform(nullPtrValue);
6772 
6773                 llvm::Value *exprVal = expr->GetValue(ctx);
6774                 llvm::Value *cmp =
6775                     ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, exprVal, nullPtrValue, "ptr_ne_NULL");
6776 
6777                 if (toType->IsVaryingType()) {
6778                     if (fromType->IsUniformType())
6779                         cmp = ctx->SmearUniform(cmp);
6780                     cmp = ctx->I1VecToBoolVec(cmp);
6781                 }
6782 
6783                 return cmp;
6784             } else {
6785                 // ptr -> int
6786                 llvm::Value *value = expr->GetValue(ctx);
6787                 if (value == NULL)
6788                     return NULL;
6789 
6790                 if (toType->IsVaryingType() && fromType->IsUniformType())
6791                     value = ctx->SmearUniform(value);
6792 
6793                 llvm::Type *llvmToType = toType->LLVMType(g->ctx);
6794                 if (llvmToType == NULL)
6795                     return NULL;
6796                 return ctx->PtrToIntInst(value, llvmToType, "ptr_typecast");
6797             }
6798         }
6799     }
6800 
6801     if (Type::EqualIgnoringConst(toType, fromType))
6802         // There's nothing to do, just return the value.  (LLVM's type
6803         // system doesn't worry about constiness.)
6804         return expr->GetValue(ctx);
6805 
6806     if (fromArrayType != NULL && toPointerType != NULL) {
6807         // implicit array to pointer to first element
6808         Expr *arrayAsPtr = lArrayToPointer(expr);
6809         if (Type::EqualIgnoringConst(arrayAsPtr->GetType(), toPointerType) == false) {
6810             AssertPos(pos,
6811                       PointerType::IsVoidPointer(toPointerType) ||
6812                           Type::EqualIgnoringConst(arrayAsPtr->GetType()->GetAsVaryingType(), toPointerType) == true);
6813             arrayAsPtr = new TypeCastExpr(toPointerType, arrayAsPtr, pos);
6814             arrayAsPtr = ::TypeCheck(arrayAsPtr);
6815             AssertPos(pos, arrayAsPtr != NULL);
6816             arrayAsPtr = ::Optimize(arrayAsPtr);
6817             AssertPos(pos, arrayAsPtr != NULL);
6818         }
6819         AssertPos(pos, Type::EqualIgnoringConst(arrayAsPtr->GetType(), toPointerType));
6820         return arrayAsPtr->GetValue(ctx);
6821     }
6822 
6823     // This also should be caught during typechecking
6824     AssertPos(pos, !(toType->IsUniformType() && fromType->IsVaryingType()));
6825 
6826     if (toArrayType != NULL && fromArrayType != NULL) {
6827         // cast array pointer from [n x foo] to [0 x foo] if needed to be able
6828         // to pass to a function that takes an unsized array as a parameter
6829         if (toArrayType->GetElementCount() != 0 && (toArrayType->GetElementCount() != fromArrayType->GetElementCount()))
6830             Warning(pos, "Type-converting array of length %d to length %d", fromArrayType->GetElementCount(),
6831                     toArrayType->GetElementCount());
6832         AssertPos(pos, Type::EqualIgnoringConst(toArrayType->GetBaseType(), fromArrayType->GetBaseType()));
6833         llvm::Value *v = expr->GetValue(ctx);
6834         llvm::Type *ptype = toType->LLVMType(g->ctx);
6835         return ctx->BitCastInst(v, ptype); //, "array_cast_0size");
6836     }
6837 
6838     const ReferenceType *toReference = CastType<ReferenceType>(toType);
6839     const ReferenceType *fromReference = CastType<ReferenceType>(fromType);
6840     if (toReference && fromReference) {
6841         const Type *toTarget = toReference->GetReferenceTarget();
6842         const Type *fromTarget = fromReference->GetReferenceTarget();
6843 
6844         const ArrayType *toArray = CastType<ArrayType>(toTarget);
6845         const ArrayType *fromArray = CastType<ArrayType>(fromTarget);
6846         if (toArray && fromArray) {
6847             // cast array pointer from [n x foo] to [0 x foo] if needed to be able
6848             // to pass to a function that takes an unsized array as a parameter
6849             if (toArray->GetElementCount() != 0 && (toArray->GetElementCount() != fromArray->GetElementCount()))
6850                 Warning(pos, "Type-converting array of length %d to length %d", fromArray->GetElementCount(),
6851                         toArray->GetElementCount());
6852             AssertPos(pos, Type::EqualIgnoringConst(toArray->GetBaseType(), fromArray->GetBaseType()));
6853             llvm::Value *v = expr->GetValue(ctx);
6854             llvm::Type *ptype = toType->LLVMType(g->ctx);
6855             return ctx->BitCastInst(v, ptype); //, "array_cast_0size");
6856         }
6857 
6858         // Just bitcast it.  See Issue #721
6859         llvm::Value *value = expr->GetValue(ctx);
6860         return ctx->BitCastInst(value, toType->LLVMType(g->ctx), "refcast");
6861     }
6862 
6863     const StructType *toStruct = CastType<StructType>(toType);
6864     const StructType *fromStruct = CastType<StructType>(fromType);
6865     if (toStruct && fromStruct) {
6866         // The only legal type conversions for structs are to go from a
6867         // uniform to a varying instance of the same struct type.
6868         AssertPos(pos, toStruct->IsVaryingType() && fromStruct->IsUniformType() &&
6869                            Type::EqualIgnoringConst(toStruct, fromStruct->GetAsVaryingType()));
6870 
6871         llvm::Value *origValue = expr->GetValue(ctx);
6872         if (!origValue)
6873             return NULL;
6874         return lUniformValueToVarying(ctx, origValue, fromType, pos);
6875     }
6876 
6877     const VectorType *toVector = CastType<VectorType>(toType);
6878     const VectorType *fromVector = CastType<VectorType>(fromType);
6879     if (toVector && fromVector) {
6880         // this should be caught during typechecking
6881         AssertPos(pos, toVector->GetElementCount() == fromVector->GetElementCount());
6882 
6883         llvm::Value *exprVal = expr->GetValue(ctx);
6884         if (!exprVal)
6885             return NULL;
6886 
6887         // Emit instructions to do type conversion of each of the elements
6888         // of the vector.
6889         // FIXME: since uniform vectors are represented as
6890         // llvm::VectorTypes, we should just be able to issue the
6891         // corresponding vector type convert, which should be more
6892         // efficient by avoiding serialization!
6893         llvm::Value *cast = llvm::UndefValue::get(toType->LLVMStorageType(g->ctx));
6894         for (int i = 0; i < toVector->GetElementCount(); ++i) {
6895             llvm::Value *ei = ctx->ExtractInst(exprVal, i);
6896             llvm::Value *conv = lTypeConvAtomic(ctx, ei, toVector->GetElementType(), fromVector->GetElementType(), pos);
6897             if (!conv)
6898                 return NULL;
6899             if ((toVector->GetElementType()->IsBoolType()) &&
6900                 (CastType<AtomicType>(toVector->GetElementType()) != NULL)) {
6901                 conv = ctx->SwitchBoolSize(conv, toVector->GetElementType()->LLVMStorageType(g->ctx));
6902             }
6903 
6904             cast = ctx->InsertInst(cast, conv, i);
6905         }
6906         return cast;
6907     }
6908 
6909     llvm::Value *exprVal = expr->GetValue(ctx);
6910     if (!exprVal)
6911         return NULL;
6912 
6913     const EnumType *fromEnum = CastType<EnumType>(fromType);
6914     const EnumType *toEnum = CastType<EnumType>(toType);
6915     if (fromEnum)
6916         // treat it as an uint32 type for the below and all will be good.
6917         fromType = fromEnum->IsUniformType() ? AtomicType::UniformUInt32 : AtomicType::VaryingUInt32;
6918     if (toEnum)
6919         // treat it as an uint32 type for the below and all will be good.
6920         toType = toEnum->IsUniformType() ? AtomicType::UniformUInt32 : AtomicType::VaryingUInt32;
6921 
6922     const AtomicType *fromAtomic = CastType<AtomicType>(fromType);
6923     // at this point, coming from an atomic type is all that's left...
6924     AssertPos(pos, fromAtomic != NULL);
6925 
6926     if (toVector) {
6927         // scalar -> short vector conversion
6928         llvm::Value *conv = lTypeConvAtomic(ctx, exprVal, toVector->GetElementType(), fromAtomic, pos);
6929         if (!conv)
6930             return NULL;
6931 
6932         llvm::Value *cast = NULL;
6933         llvm::Type *toTypeLLVM = toType->LLVMStorageType(g->ctx);
6934         if (llvm::isa<llvm::VectorType>(toTypeLLVM)) {
6935             // Example uniform float => uniform float<3>
6936             cast = ctx->BroadcastValue(conv, toTypeLLVM);
6937         } else if (llvm::isa<llvm::ArrayType>(toTypeLLVM)) {
6938             // Example varying float => varying float<3>
6939             cast = llvm::UndefValue::get(toType->LLVMStorageType(g->ctx));
6940             for (int i = 0; i < toVector->GetElementCount(); ++i) {
6941                 if ((toVector->GetElementType()->IsBoolType()) &&
6942                     (CastType<AtomicType>(toVector->GetElementType()) != NULL)) {
6943                     conv = ctx->SwitchBoolSize(conv, toVector->GetElementType()->LLVMStorageType(g->ctx));
6944                 }
6945                 // Here's InsertInst produces InsertValueInst.
6946                 cast = ctx->InsertInst(cast, conv, i);
6947             }
6948         } else {
6949             FATAL("TypeCastExpr::GetValue: problem with cast");
6950         }
6951 
6952         return cast;
6953     } else if (toPointerType != NULL) {
6954         // int -> ptr
6955         if (toType->IsVaryingType() && fromType->IsUniformType())
6956             exprVal = ctx->SmearUniform(exprVal);
6957 
6958         llvm::Type *llvmToType = toType->LLVMType(g->ctx);
6959         if (llvmToType == NULL)
6960             return NULL;
6961 
6962         return ctx->IntToPtrInst(exprVal, llvmToType, "int_to_ptr");
6963     } else {
6964         const AtomicType *toAtomic = CastType<AtomicType>(toType);
6965         // typechecking should ensure this is the case
6966         AssertPos(pos, toAtomic != NULL);
6967 
6968         return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos);
6969     }
6970 }
6971 
GetLValue(FunctionEmitContext * ctx) const6972 llvm::Value *TypeCastExpr::GetLValue(FunctionEmitContext *ctx) const {
6973     if (GetLValueType() != NULL) {
6974         return GetValue(ctx);
6975     } else {
6976         return NULL;
6977     }
6978 }
6979 
GetType() const6980 const Type *TypeCastExpr::GetType() const {
6981     // Here we try to resolve situation where (base_type) can be treated as
6982     // (uniform base_type) of (varying base_type). This is a part of function
6983     // TypeCastExpr::TypeCheck. After implementation of operators we
6984     // have to have this functionality here.
6985     if (expr == NULL)
6986         return NULL;
6987     const Type *toType = type, *fromType = expr->GetType();
6988     if (toType == NULL || fromType == NULL)
6989         return NULL;
6990 
6991     if (toType->HasUnboundVariability()) {
6992         if (fromType->IsUniformType()) {
6993             toType = type->ResolveUnboundVariability(Variability::Uniform);
6994         } else {
6995             toType = type->ResolveUnboundVariability(Variability::Varying);
6996         }
6997     }
6998     AssertPos(pos, toType->HasUnboundVariability() == false);
6999     return toType;
7000 }
7001 
GetLValueType() const7002 const Type *TypeCastExpr::GetLValueType() const {
7003     AssertPos(pos, type->HasUnboundVariability() == false);
7004     if (CastType<PointerType>(GetType()) != NULL) {
7005         return type;
7006     } else {
7007         return NULL;
7008     }
7009 }
7010 
/** Returns a non-const variant of the given type.  For pointer types the
    pointed-to type is processed recursively and a new PointerType with
    the same variability is built around it; any other type is converted
    with GetAsNonConstType(). */
static const Type *lDeconstifyType(const Type *t) {
    const PointerType *ptrType = CastType<PointerType>(t);
    if (ptrType == NULL)
        return t->GetAsNonConstType();

    const Type *strippedBase = lDeconstifyType(ptrType->GetBaseType());
    // NOTE(review): the trailing 'false' is presumably the pointer's own
    // const flag -- confirm against the PointerType constructor.
    return new PointerType(strippedBase, ptrType->GetVariability(), false);
}
7018 
TypeCheck()7019 Expr *TypeCastExpr::TypeCheck() {
7020     if (expr == NULL)
7021         return NULL;
7022 
7023     const Type *toType = type, *fromType = expr->GetType();
7024     if (toType == NULL || fromType == NULL)
7025         return NULL;
7026 
7027     if (toType->HasUnboundVariability() && fromType->IsUniformType()) {
7028         TypeCastExpr *tce = new TypeCastExpr(toType->GetAsUniformType(), expr, pos);
7029         return ::TypeCheck(tce);
7030     }
7031     type = toType = type->ResolveUnboundVariability(Variability::Varying);
7032 
7033     fromType = lDeconstifyType(fromType);
7034     toType = lDeconstifyType(toType);
7035 
7036     // Anything can be cast to void...
7037     if (toType->IsVoidType())
7038         return this;
7039 
7040     if (fromType->IsVoidType() || (fromType->IsVaryingType() && toType->IsUniformType())) {
7041         Error(pos, "Can't type cast from type \"%s\" to type \"%s\"", fromType->GetString().c_str(),
7042               toType->GetString().c_str());
7043         return NULL;
7044     }
7045 
7046     // First some special cases that we allow only with an explicit type cast
7047     const PointerType *fromPtr = CastType<PointerType>(fromType);
7048     const PointerType *toPtr = CastType<PointerType>(toType);
7049     if (fromPtr != NULL && toPtr != NULL)
7050         // allow explicit typecasts between any two different pointer types
7051         return this;
7052 
7053     const ReferenceType *fromRef = CastType<ReferenceType>(fromType);
7054     const ReferenceType *toRef = CastType<ReferenceType>(toType);
7055     if (fromRef != NULL && toRef != NULL) {
7056         // allow explicit typecasts between any two different reference types
7057         // Issues #721
7058         return this;
7059     }
7060 
7061     const AtomicType *fromAtomic = CastType<AtomicType>(fromType);
7062     const AtomicType *toAtomic = CastType<AtomicType>(toType);
7063     const EnumType *fromEnum = CastType<EnumType>(fromType);
7064     const EnumType *toEnum = CastType<EnumType>(toType);
7065     if ((fromAtomic || fromEnum) && (toAtomic || toEnum))
7066         // Allow explicit casts between all of these
7067         return this;
7068 
7069     // ptr -> int type casts
7070     if (fromPtr != NULL && toAtomic != NULL && toAtomic->IsIntType()) {
7071         bool safeCast =
7072             (toAtomic->basicType == AtomicType::TYPE_INT64 || toAtomic->basicType == AtomicType::TYPE_UINT64);
7073         if (g->target->is32Bit())
7074             safeCast |=
7075                 (toAtomic->basicType == AtomicType::TYPE_INT32 || toAtomic->basicType == AtomicType::TYPE_UINT32);
7076         if (safeCast == false)
7077             Warning(pos,
7078                     "Pointer type cast of type \"%s\" to integer type "
7079                     "\"%s\" may lose information.",
7080                     fromType->GetString().c_str(), toType->GetString().c_str());
7081         return this;
7082     }
7083 
7084     // int -> ptr
7085     if (fromAtomic != NULL && fromAtomic->IsIntType() && toPtr != NULL)
7086         return this;
7087 
7088     // And otherwise see if it's one of the conversions allowed to happen
7089     // implicitly.
7090     Expr *e = TypeConvertExpr(expr, toType, "type cast expression");
7091     if (e == NULL)
7092         return NULL;
7093     else
7094         return e;
7095 }
7096 
Optimize()7097 Expr *TypeCastExpr::Optimize() {
7098     ConstExpr *constExpr = llvm::dyn_cast<ConstExpr>(expr);
7099     if (constExpr == NULL)
7100         // We can't do anything if this isn't a const expr
7101         return this;
7102 
7103     const Type *toType = GetType();
7104     const AtomicType *toAtomic = CastType<AtomicType>(toType);
7105     const EnumType *toEnum = CastType<EnumType>(toType);
7106     // If we're not casting to an atomic or enum type, we can't do anything
7107     // here, since ConstExprs can only represent those two types.  (So
7108     // e.g. we're casting from an int to an int<4>.)
7109     if (toAtomic == NULL && toEnum == NULL)
7110         return this;
7111 
7112     bool forceVarying = toType->IsVaryingType();
7113 
7114     // All of the type conversion smarts we need is already in the
7115     // ConstExpr::GetValues(), etc., methods, so we just need to call the
7116     // appropriate one for the type that this cast is converting to.
7117     AtomicType::BasicType basicType = toAtomic ? toAtomic->basicType : AtomicType::TYPE_UINT32;
7118     switch (basicType) {
7119     case AtomicType::TYPE_BOOL: {
7120         bool bv[ISPC_MAX_NVEC];
7121         constExpr->GetValues(bv, forceVarying);
7122         return new ConstExpr(toType, bv, pos);
7123     }
7124     case AtomicType::TYPE_INT8: {
7125         int8_t iv[ISPC_MAX_NVEC];
7126         constExpr->GetValues(iv, forceVarying);
7127         return new ConstExpr(toType, iv, pos);
7128     }
7129     case AtomicType::TYPE_UINT8: {
7130         uint8_t uv[ISPC_MAX_NVEC];
7131         constExpr->GetValues(uv, forceVarying);
7132         return new ConstExpr(toType, uv, pos);
7133     }
7134     case AtomicType::TYPE_INT16: {
7135         int16_t iv[ISPC_MAX_NVEC];
7136         constExpr->GetValues(iv, forceVarying);
7137         return new ConstExpr(toType, iv, pos);
7138     }
7139     case AtomicType::TYPE_UINT16: {
7140         uint16_t uv[ISPC_MAX_NVEC];
7141         constExpr->GetValues(uv, forceVarying);
7142         return new ConstExpr(toType, uv, pos);
7143     }
7144     case AtomicType::TYPE_INT32: {
7145         int32_t iv[ISPC_MAX_NVEC];
7146         constExpr->GetValues(iv, forceVarying);
7147         return new ConstExpr(toType, iv, pos);
7148     }
7149     case AtomicType::TYPE_UINT32: {
7150         uint32_t uv[ISPC_MAX_NVEC];
7151         constExpr->GetValues(uv, forceVarying);
7152         return new ConstExpr(toType, uv, pos);
7153     }
7154     case AtomicType::TYPE_FLOAT: {
7155         float fv[ISPC_MAX_NVEC];
7156         constExpr->GetValues(fv, forceVarying);
7157         return new ConstExpr(toType, fv, pos);
7158     }
7159     case AtomicType::TYPE_INT64: {
7160         int64_t iv[ISPC_MAX_NVEC];
7161         constExpr->GetValues(iv, forceVarying);
7162         return new ConstExpr(toType, iv, pos);
7163     }
7164     case AtomicType::TYPE_UINT64: {
7165         uint64_t uv[ISPC_MAX_NVEC];
7166         constExpr->GetValues(uv, forceVarying);
7167         return new ConstExpr(toType, uv, pos);
7168     }
7169     case AtomicType::TYPE_DOUBLE: {
7170         double dv[ISPC_MAX_NVEC];
7171         constExpr->GetValues(dv, forceVarying);
7172         return new ConstExpr(toType, dv, pos);
7173     }
7174     default:
7175         FATAL("unimplemented");
7176     }
7177     return this;
7178 }
7179 
EstimateCost() const7180 int TypeCastExpr::EstimateCost() const {
7181     if (llvm::dyn_cast<ConstExpr>(expr) != NULL)
7182         return 0;
7183 
7184     // FIXME: return COST_TYPECAST_COMPLEX when appropriate
7185     return COST_TYPECAST_SIMPLE;
7186 }
7187 
Print() const7188 void TypeCastExpr::Print() const {
7189     printf("[%s] type cast (", GetType()->GetString().c_str());
7190     expr->Print();
7191     printf(")");
7192     pos.Print();
7193 }
7194 
GetBaseSymbol() const7195 Symbol *TypeCastExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7196 
lConvertPointerConstant(llvm::Constant * c,const Type * constType)7197 static llvm::Constant *lConvertPointerConstant(llvm::Constant *c, const Type *constType) {
7198     if (c == NULL || constType->IsUniformType())
7199         return c;
7200 
7201     // Handle conversion to int and then to vector of int or array of int
7202     // (for varying and soa types, respectively)
7203     llvm::Constant *intPtr = llvm::ConstantExpr::getPtrToInt(c, LLVMTypes::PointerIntType);
7204     Assert(constType->IsVaryingType() || constType->IsSOAType());
7205     int count = constType->IsVaryingType() ? g->target->getVectorWidth() : constType->GetSOAWidth();
7206 
7207     std::vector<llvm::Constant *> smear;
7208     for (int i = 0; i < count; ++i)
7209         smear.push_back(intPtr);
7210 
7211     if (constType->IsVaryingType())
7212         return llvm::ConstantVector::get(smear);
7213     else {
7214         llvm::ArrayType *at = llvm::ArrayType::get(LLVMTypes::PointerIntType, count);
7215         return llvm::ConstantArray::get(at, smear);
7216     }
7217 }
7218 
GetConstant(const Type * constType) const7219 std::pair<llvm::Constant *, bool> TypeCastExpr::GetConstant(const Type *constType) const {
7220     // We don't need to worry about most the basic cases where the type
7221     // cast can resolve to a constant here, since the
7222     // TypeCastExpr::Optimize() method generally ends up doing the type
7223     // conversion and returning a ConstExpr, which in turn will have its
7224     // GetConstant() method called.  However, because ConstExpr currently
7225     // can't represent pointer values, we have to handle a few cases
7226     // related to pointers here:
7227     //
7228     // 1. Null pointer (NULL, 0) valued initializers
7229     // 2. Converting function types to pointer-to-function types
7230     // 3. And converting these from uniform to the varying/soa equivalents.
7231     //
7232 
7233     if ((CastType<PointerType>(constType) == NULL) && (llvm::dyn_cast<SizeOfExpr>(expr) == NULL))
7234         return std::pair<llvm::Constant *, bool>(NULL, false);
7235 
7236     llvm::Value *ptr = NULL;
7237     if (GetBaseSymbol())
7238         ptr = GetBaseSymbol()->storagePtr;
7239 
7240     if (ptr && llvm::dyn_cast<llvm::GlobalVariable>(ptr)) {
7241         if (CastType<ArrayType>(expr->GetType())) {
7242             if (llvm::Constant *c = llvm::dyn_cast<llvm::Constant>(ptr)) {
7243                 llvm::Value *offsets[2] = {LLVMInt32(0), LLVMInt32(0)};
7244                 llvm::ArrayRef<llvm::Value *> arrayRef(&offsets[0], &offsets[2]);
7245                 llvm::Value *resultPtr = llvm::ConstantExpr::getGetElementPtr(PTYPE(c), c, arrayRef);
7246                 if (resultPtr->getType() == constType->LLVMType(g->ctx)) {
7247                     llvm::Constant *ret = llvm::dyn_cast<llvm::Constant>(resultPtr);
7248                     return std::pair<llvm::Constant *, bool>(ret, false);
7249                 }
7250             }
7251         }
7252     }
7253 
7254     std::pair<llvm::Constant *, bool> cPair = expr->GetConstant(constType->GetAsUniformType());
7255     llvm::Constant *c = cPair.first;
7256     return std::pair<llvm::Constant *, bool>(lConvertPointerConstant(c, constType), cPair.second);
7257 }
7258 
7259 ///////////////////////////////////////////////////////////////////////////
7260 // ReferenceExpr
7261 
ReferenceExpr(Expr * e,SourcePos p)7262 ReferenceExpr::ReferenceExpr(Expr *e, SourcePos p) : Expr(p, ReferenceExprID) { expr = e; }
7263 
GetValue(FunctionEmitContext * ctx) const7264 llvm::Value *ReferenceExpr::GetValue(FunctionEmitContext *ctx) const {
7265     ctx->SetDebugPos(pos);
7266     if (expr == NULL) {
7267         AssertPos(pos, m->errorCount > 0);
7268         return NULL;
7269     }
7270 
7271     llvm::Value *value = expr->GetLValue(ctx);
7272     if (value != NULL)
7273         return value;
7274 
7275     // value is NULL if the expression is a temporary; in this case, we'll
7276     // allocate storage for it so that we can return the pointer to that...
7277     const Type *type;
7278     if ((type = expr->GetType()) == NULL || type->LLVMType(g->ctx) == NULL) {
7279         AssertPos(pos, m->errorCount > 0);
7280         return NULL;
7281     }
7282 
7283     value = expr->GetValue(ctx);
7284     if (value == NULL) {
7285         AssertPos(pos, m->errorCount > 0);
7286         return NULL;
7287     }
7288 
7289     llvm::Value *ptr = ctx->AllocaInst(type);
7290     ctx->StoreInst(value, ptr, type, expr->GetType()->IsUniformType());
7291     return ptr;
7292 }
7293 
GetBaseSymbol() const7294 Symbol *ReferenceExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7295 
GetType() const7296 const Type *ReferenceExpr::GetType() const {
7297     if (!expr)
7298         return NULL;
7299 
7300     const Type *type = expr->GetType();
7301     if (!type)
7302         return NULL;
7303 
7304     return new ReferenceType(type);
7305 }
7306 
GetLValueType() const7307 const Type *ReferenceExpr::GetLValueType() const {
7308     if (!expr)
7309         return NULL;
7310 
7311     const Type *type = expr->GetType();
7312     if (!type)
7313         return NULL;
7314 
7315     return PointerType::GetUniform(type);
7316 }
7317 
Optimize()7318 Expr *ReferenceExpr::Optimize() {
7319     if (expr == NULL)
7320         return NULL;
7321     return this;
7322 }
7323 
TypeCheck()7324 Expr *ReferenceExpr::TypeCheck() {
7325     if (expr == NULL)
7326         return NULL;
7327     return this;
7328 }
7329 
EstimateCost() const7330 int ReferenceExpr::EstimateCost() const { return 0; }
7331 
Print() const7332 void ReferenceExpr::Print() const {
7333     if (expr == NULL || GetType() == NULL)
7334         return;
7335 
7336     printf("[%s] &(", GetType()->GetString().c_str());
7337     expr->Print();
7338     printf(")");
7339     pos.Print();
7340 }
7341 
7342 ///////////////////////////////////////////////////////////////////////////
7343 // DerefExpr
7344 
DerefExpr(Expr * e,SourcePos p,unsigned scid)7345 DerefExpr::DerefExpr(Expr *e, SourcePos p, unsigned scid) : Expr(p, scid) { expr = e; }
7346 
GetValue(FunctionEmitContext * ctx) const7347 llvm::Value *DerefExpr::GetValue(FunctionEmitContext *ctx) const {
7348     if (expr == NULL)
7349         return NULL;
7350     llvm::Value *ptr = expr->GetValue(ctx);
7351     if (ptr == NULL)
7352         return NULL;
7353     const Type *type = expr->GetType();
7354     if (type == NULL)
7355         return NULL;
7356 
7357     if (lVaryingStructHasUniformMember(type, pos))
7358         return NULL;
7359 
7360     // If dealing with 'varying * varying' add required offsets.
7361     ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, type);
7362 
7363     Symbol *baseSym = expr->GetBaseSymbol();
7364     llvm::Value *mask = baseSym ? lMaskForSymbol(baseSym, ctx) : ctx->GetFullMask();
7365 
7366     ctx->SetDebugPos(pos);
7367     return ctx->LoadInst(ptr, mask, type);
7368 }
7369 
GetLValue(FunctionEmitContext * ctx) const7370 llvm::Value *DerefExpr::GetLValue(FunctionEmitContext *ctx) const {
7371     if (expr == NULL)
7372         return NULL;
7373     return expr->GetValue(ctx);
7374 }
7375 
GetLValueType() const7376 const Type *DerefExpr::GetLValueType() const {
7377     if (expr == NULL)
7378         return NULL;
7379     return expr->GetType();
7380 }
7381 
GetBaseSymbol() const7382 Symbol *DerefExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7383 
Optimize()7384 Expr *DerefExpr::Optimize() {
7385     if (expr == NULL)
7386         return NULL;
7387     return this;
7388 }
7389 
7390 ///////////////////////////////////////////////////////////////////////////
7391 // PtrDerefExpr
7392 
PtrDerefExpr(Expr * e,SourcePos p)7393 PtrDerefExpr::PtrDerefExpr(Expr *e, SourcePos p) : DerefExpr(e, p, PtrDerefExprID) {}
7394 
GetType() const7395 const Type *PtrDerefExpr::GetType() const {
7396     const Type *type;
7397     if (expr == NULL || (type = expr->GetType()) == NULL) {
7398         AssertPos(pos, m->errorCount > 0);
7399         return NULL;
7400     }
7401     AssertPos(pos, CastType<PointerType>(type) != NULL);
7402 
7403     if (type->IsUniformType())
7404         return type->GetBaseType();
7405     else
7406         return type->GetBaseType()->GetAsVaryingType();
7407 }
7408 
TypeCheck()7409 Expr *PtrDerefExpr::TypeCheck() {
7410     const Type *type;
7411     if (expr == NULL || (type = expr->GetType()) == NULL) {
7412         AssertPos(pos, m->errorCount > 0);
7413         return NULL;
7414     }
7415 
7416     if (const PointerType *pt = CastType<PointerType>(type)) {
7417         if (pt->GetBaseType()->IsVoidType()) {
7418             Error(pos, "Illegal to dereference void pointer type \"%s\".", type->GetString().c_str());
7419             return NULL;
7420         }
7421     } else {
7422         Error(pos, "Illegal to dereference non-pointer type \"%s\".", type->GetString().c_str());
7423         return NULL;
7424     }
7425 
7426     return this;
7427 }
7428 
EstimateCost() const7429 int PtrDerefExpr::EstimateCost() const {
7430     const Type *type;
7431     if (expr == NULL || (type = expr->GetType()) == NULL) {
7432         AssertPos(pos, m->errorCount > 0);
7433         return 0;
7434     }
7435 
7436     if (type->IsVaryingType())
7437         // Be pessimistic; some of these will later be optimized into
7438         // vector loads/stores..
7439         return COST_GATHER + COST_DEREF;
7440     else
7441         return COST_DEREF;
7442 }
7443 
Print() const7444 void PtrDerefExpr::Print() const {
7445     if (expr == NULL || GetType() == NULL)
7446         return;
7447 
7448     printf("[%s] *(", GetType()->GetString().c_str());
7449     expr->Print();
7450     printf(")");
7451     pos.Print();
7452 }
7453 
7454 ///////////////////////////////////////////////////////////////////////////
7455 // RefDerefExpr
7456 
RefDerefExpr(Expr * e,SourcePos p)7457 RefDerefExpr::RefDerefExpr(Expr *e, SourcePos p) : DerefExpr(e, p, RefDerefExprID) {}
7458 
GetType() const7459 const Type *RefDerefExpr::GetType() const {
7460     const Type *type;
7461     if (expr == NULL || (type = expr->GetType()) == NULL) {
7462         AssertPos(pos, m->errorCount > 0);
7463         return NULL;
7464     }
7465 
7466     AssertPos(pos, CastType<ReferenceType>(type) != NULL);
7467     return type->GetReferenceTarget();
7468 }
7469 
TypeCheck()7470 Expr *RefDerefExpr::TypeCheck() {
7471     const Type *type;
7472     if (expr == NULL || (type = expr->GetType()) == NULL) {
7473         AssertPos(pos, m->errorCount > 0);
7474         return NULL;
7475     }
7476 
7477     // We only create RefDerefExprs internally for references in
7478     // expressions, so we should never create one with a non-reference
7479     // expression...
7480     AssertPos(pos, CastType<ReferenceType>(type) != NULL);
7481 
7482     return this;
7483 }
7484 
EstimateCost() const7485 int RefDerefExpr::EstimateCost() const {
7486     if (expr == NULL)
7487         return 0;
7488 
7489     return COST_DEREF;
7490 }
7491 
Print() const7492 void RefDerefExpr::Print() const {
7493     if (expr == NULL || GetType() == NULL)
7494         return;
7495 
7496     printf("[%s] deref-reference (", GetType()->GetString().c_str());
7497     expr->Print();
7498     printf(")");
7499     pos.Print();
7500 }
7501 
7502 ///////////////////////////////////////////////////////////////////////////
7503 // AddressOfExpr
7504 
AddressOfExpr(Expr * e,SourcePos p)7505 AddressOfExpr::AddressOfExpr(Expr *e, SourcePos p) : Expr(p, AddressOfExprID), expr(e) {}
7506 
GetValue(FunctionEmitContext * ctx) const7507 llvm::Value *AddressOfExpr::GetValue(FunctionEmitContext *ctx) const {
7508     ctx->SetDebugPos(pos);
7509     if (expr == NULL)
7510         return NULL;
7511 
7512     const Type *exprType = expr->GetType();
7513     if (CastType<ReferenceType>(exprType) != NULL || CastType<FunctionType>(exprType) != NULL)
7514         return expr->GetValue(ctx);
7515     else
7516         return expr->GetLValue(ctx);
7517 }
7518 
GetType() const7519 const Type *AddressOfExpr::GetType() const {
7520     if (expr == NULL)
7521         return NULL;
7522 
7523     const Type *exprType = expr->GetType();
7524     if (CastType<ReferenceType>(exprType) != NULL)
7525         return PointerType::GetUniform(exprType->GetReferenceTarget());
7526 
7527     const Type *t = expr->GetLValueType();
7528     if (t != NULL)
7529         return t;
7530     else {
7531         t = expr->GetType();
7532         if (t == NULL) {
7533             AssertPos(pos, m->errorCount > 0);
7534             return NULL;
7535         }
7536         return PointerType::GetUniform(t);
7537     }
7538 }
7539 
GetLValueType() const7540 const Type *AddressOfExpr::GetLValueType() const {
7541     if (!expr)
7542         return NULL;
7543 
7544     const Type *type = expr->GetType();
7545     if (!type)
7546         return NULL;
7547 
7548     return PointerType::GetUniform(type);
7549 }
7550 
GetBaseSymbol() const7551 Symbol *AddressOfExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; }
7552 
Print() const7553 void AddressOfExpr::Print() const {
7554     printf("&(");
7555     if (expr)
7556         expr->Print();
7557     else
7558         printf("NULL expr");
7559     printf(")");
7560     pos.Print();
7561 }
7562 
TypeCheck()7563 Expr *AddressOfExpr::TypeCheck() {
7564     const Type *exprType;
7565     if (expr == NULL || (exprType = expr->GetType()) == NULL) {
7566         AssertPos(pos, m->errorCount > 0);
7567         return NULL;
7568     }
7569 
7570     if (CastType<ReferenceType>(exprType) != NULL || CastType<FunctionType>(exprType) != NULL) {
7571         return this;
7572     }
7573 
7574     if (expr->GetLValueType() != NULL)
7575         return this;
7576 
7577     Error(expr->pos, "Illegal to take address of non-lvalue or function.");
7578     return NULL;
7579 }
7580 
Optimize()7581 Expr *AddressOfExpr::Optimize() { return this; }
7582 
EstimateCost() const7583 int AddressOfExpr::EstimateCost() const { return 0; }
7584 
GetConstant(const Type * type) const7585 std::pair<llvm::Constant *, bool> AddressOfExpr::GetConstant(const Type *type) const {
7586     if (expr == NULL || expr->GetType() == NULL) {
7587         AssertPos(pos, m->errorCount > 0);
7588         return std::pair<llvm::Constant *, bool>(NULL, false);
7589     }
7590 
7591     const PointerType *pt = CastType<PointerType>(type);
7592     if (pt == NULL)
7593         return std::pair<llvm::Constant *, bool>(NULL, false);
7594 
7595     bool isNotValidForMultiTargetGlobal = false;
7596     const FunctionType *ft = CastType<FunctionType>(pt->GetBaseType());
7597     if (ft != NULL) {
7598         std::pair<llvm::Constant *, bool> cPair = expr->GetConstant(ft);
7599         llvm::Constant *c = cPair.first;
7600         return std::pair<llvm::Constant *, bool>(lConvertPointerConstant(c, type), cPair.second);
7601     }
7602     llvm::Value *ptr = NULL;
7603     if (GetBaseSymbol())
7604         ptr = GetBaseSymbol()->storagePtr;
7605     if (ptr && llvm::dyn_cast<llvm::GlobalVariable>(ptr)) {
7606         const Type *eTYPE = GetType();
7607         if (type->LLVMType(g->ctx) == eTYPE->LLVMType(g->ctx)) {
7608             if (llvm::dyn_cast<SymbolExpr>(expr) != NULL) {
7609                 return std::pair<llvm::Constant *, bool>(llvm::cast<llvm::Constant>(ptr),
7610                                                          isNotValidForMultiTargetGlobal);
7611 
7612             } else if (IndexExpr *IExpr = llvm::dyn_cast<IndexExpr>(expr)) {
7613                 std::vector<llvm::Value *> gepIndex;
7614                 Expr *mBaseExpr = NULL;
7615                 while (IExpr) {
7616                     std::pair<llvm::Constant *, bool> cIndexPair = IExpr->index->GetConstant(IExpr->index->GetType());
7617                     llvm::Constant *cIndex = cIndexPair.first;
7618                     if (cIndex == NULL)
7619                         return std::pair<llvm::Constant *, bool>(NULL, false);
7620                     gepIndex.insert(gepIndex.begin(), cIndex);
7621                     mBaseExpr = IExpr->baseExpr;
7622                     IExpr = llvm::dyn_cast<IndexExpr>(mBaseExpr);
7623                     isNotValidForMultiTargetGlobal = isNotValidForMultiTargetGlobal || cIndexPair.second;
7624                 }
7625                 // The base expression needs to be a global symbol so that the
7626                 // address is a constant.
7627                 if (llvm::dyn_cast<SymbolExpr>(mBaseExpr) == NULL)
7628                     return std::pair<llvm::Constant *, bool>(NULL, false);
7629                 gepIndex.insert(gepIndex.begin(), LLVMInt64(0));
7630                 llvm::Constant *c = llvm::cast<llvm::Constant>(ptr);
7631                 llvm::Constant *c1 = llvm::ConstantExpr::getGetElementPtr(PTYPE(c), c, gepIndex);
7632                 return std::pair<llvm::Constant *, bool>(c1, isNotValidForMultiTargetGlobal);
7633             }
7634         }
7635     }
7636     return std::pair<llvm::Constant *, bool>(NULL, false);
7637 }
7638 
7639 ///////////////////////////////////////////////////////////////////////////
7640 // SizeOfExpr
7641 
SizeOfExpr(Expr * e,SourcePos p)7642 SizeOfExpr::SizeOfExpr(Expr *e, SourcePos p) : Expr(p, SizeOfExprID), expr(e), type(NULL) {}
7643 
SizeOfExpr(const Type * t,SourcePos p)7644 SizeOfExpr::SizeOfExpr(const Type *t, SourcePos p) : Expr(p, SizeOfExprID), expr(NULL), type(t) {
7645     type = type->ResolveUnboundVariability(Variability::Varying);
7646 }
7647 
GetValue(FunctionEmitContext * ctx) const7648 llvm::Value *SizeOfExpr::GetValue(FunctionEmitContext *ctx) const {
7649     ctx->SetDebugPos(pos);
7650     const Type *t = expr ? expr->GetType() : type;
7651     if (t == NULL)
7652         return NULL;
7653 
7654     llvm::Type *llvmType = t->LLVMType(g->ctx);
7655     if (llvmType == NULL)
7656         return NULL;
7657 
7658     return g->target->SizeOf(llvmType, ctx->GetCurrentBasicBlock());
7659 }
7660 
GetType() const7661 const Type *SizeOfExpr::GetType() const {
7662     return (g->target->is32Bit() || g->opt.force32BitAddressing) ? AtomicType::UniformUInt32
7663                                                                  : AtomicType::UniformUInt64;
7664 }
7665 
Print() const7666 void SizeOfExpr::Print() const {
7667     printf("Sizeof (");
7668     if (expr != NULL)
7669         expr->Print();
7670     const Type *t = expr ? expr->GetType() : type;
7671     if (t != NULL)
7672         printf(" [type %s]", t->GetString().c_str());
7673     printf(")");
7674     pos.Print();
7675 }
7676 
TypeCheck()7677 Expr *SizeOfExpr::TypeCheck() {
7678     // Can't compute the size of a struct without a definition
7679     if (type != NULL && CastType<UndefinedStructType>(type) != NULL) {
7680         Error(pos,
7681               "Can't compute the size of declared but not defined "
7682               "struct type \"%s\".",
7683               type->GetString().c_str());
7684         return NULL;
7685     }
7686 
7687     return this;
7688 }
7689 
Optimize()7690 Expr *SizeOfExpr::Optimize() { return this; }
7691 
EstimateCost() const7692 int SizeOfExpr::EstimateCost() const { return 0; }
7693 
GetConstant(const Type * rtype) const7694 std::pair<llvm::Constant *, bool> SizeOfExpr::GetConstant(const Type *rtype) const {
7695     const Type *t = expr ? expr->GetType() : type;
7696     if (t == NULL)
7697         return std::pair<llvm::Constant *, bool>(NULL, false);
7698 
7699     bool isNotValidForMultiTargetGlobal = false;
7700     if (t->IsVaryingType())
7701         isNotValidForMultiTargetGlobal = true;
7702 
7703     llvm::Type *llvmType = t->LLVMType(g->ctx);
7704     if (llvmType == NULL)
7705         return std::pair<llvm::Constant *, bool>(NULL, false);
7706 
7707     uint64_t byteSize = g->target->getDataLayout()->getTypeStoreSize(llvmType);
7708     return std::pair<llvm::Constant *, bool>(llvm::ConstantInt::get(rtype->LLVMType(g->ctx), byteSize),
7709                                              isNotValidForMultiTargetGlobal);
7710 }
7711 
7712 ///////////////////////////////////////////////////////////////////////////
7713 // AllocaExpr
7714 
AllocaExpr(Expr * e,SourcePos p)7715 AllocaExpr::AllocaExpr(Expr *e, SourcePos p) : Expr(p, AllocaExprID), expr(e) {}
7716 
GetValue(FunctionEmitContext * ctx) const7717 llvm::Value *AllocaExpr::GetValue(FunctionEmitContext *ctx) const {
7718     ctx->SetDebugPos(pos);
7719     if (expr == NULL)
7720         return NULL;
7721     llvm::Value *llvmValue = expr->GetValue(ctx);
7722     if (llvmValue == NULL)
7723         return NULL;
7724     llvm::Value *resultPtr = ctx->AllocaInst((LLVMTypes::VoidPointerType)->getElementType(), llvmValue, "allocaExpr",
7725                                              16, false); // 16 byte stack alignment.
7726     return resultPtr;
7727 }
7728 
GetType() const7729 const Type *AllocaExpr::GetType() const { return PointerType::Void; }
7730 
Print() const7731 void AllocaExpr::Print() const {
7732     printf("AllocaExpr (");
7733     if (expr != NULL)
7734         expr->Print();
7735     const Type *t = expr ? expr->GetType() : NULL;
7736     if (t != NULL)
7737         printf(" [type %s]", t->GetString().c_str());
7738     printf(")");
7739     pos.Print();
7740 }
7741 
TypeCheck()7742 Expr *AllocaExpr::TypeCheck() {
7743     if (expr == NULL) {
7744         return NULL;
7745     }
7746 
7747     if (g->target->isGenXTarget()) {
7748         Error(pos, "\"alloca()\" is not supported for genx-* targets yet.");
7749         return NULL;
7750     }
7751     const Type *argType = expr ? expr->GetType() : NULL;
7752     const Type *sizeType = m->symbolTable->LookupType("size_t");
7753     Assert(sizeType != NULL);
7754     if (!Type::Equal(sizeType->GetAsUniformType(), expr->GetType())) {
7755         expr = TypeConvertExpr(expr, sizeType->GetAsUniformType(), "Alloca_arg");
7756     }
7757     if (expr == NULL) {
7758         Error(pos, "\"alloca()\" cannot have an argument of type \"%s\".", argType->GetString().c_str());
7759         return NULL;
7760     }
7761 
7762     return this;
7763 }
7764 
Optimize()7765 Expr *AllocaExpr::Optimize() { return this; }
7766 
EstimateCost() const7767 int AllocaExpr::EstimateCost() const { return 0; }
7768 
7769 ///////////////////////////////////////////////////////////////////////////
7770 // SymbolExpr
7771 
SymbolExpr(Symbol * s,SourcePos p)7772 SymbolExpr::SymbolExpr(Symbol *s, SourcePos p) : Expr(p, SymbolExprID) { symbol = s; }
7773 
GetValue(FunctionEmitContext * ctx) const7774 llvm::Value *SymbolExpr::GetValue(FunctionEmitContext *ctx) const {
7775     // storagePtr may be NULL due to an earlier compilation error
7776     if (!symbol || !symbol->storagePtr)
7777         return NULL;
7778     ctx->SetDebugPos(pos);
7779 
7780     std::string loadName = symbol->name + std::string("_load");
7781 #ifdef ISPC_GENX_ENABLED
7782     // TODO: this is a temporary workaround and will be changed as part
7783     // of SPIR-V emitting solution
7784     if (ctx->emitGenXHardwareMask() && symbol->name == "__mask") {
7785         return ctx->GenXSimdCFPredicate(LLVMMaskAllOn);
7786     }
7787 #endif
7788     return ctx->LoadInst(symbol->storagePtr, symbol->type, loadName.c_str());
7789 }
7790 
GetLValue(FunctionEmitContext * ctx) const7791 llvm::Value *SymbolExpr::GetLValue(FunctionEmitContext *ctx) const {
7792     if (symbol == NULL)
7793         return NULL;
7794     ctx->SetDebugPos(pos);
7795     return symbol->storagePtr;
7796 }
7797 
GetLValueType() const7798 const Type *SymbolExpr::GetLValueType() const {
7799     if (symbol == NULL)
7800         return NULL;
7801 
7802     if (CastType<ReferenceType>(symbol->type) != NULL)
7803         return PointerType::GetUniform(symbol->type->GetReferenceTarget());
7804     else
7805         return PointerType::GetUniform(symbol->type);
7806 }
7807 
GetBaseSymbol() const7808 Symbol *SymbolExpr::GetBaseSymbol() const { return symbol; }
7809 
GetType() const7810 const Type *SymbolExpr::GetType() const { return symbol ? symbol->type : NULL; }
7811 
TypeCheck()7812 Expr *SymbolExpr::TypeCheck() { return this; }
7813 
Optimize()7814 Expr *SymbolExpr::Optimize() {
7815     if (symbol == NULL)
7816         return NULL;
7817     else if (symbol->constValue != NULL) {
7818         AssertPos(pos, GetType()->IsConstType());
7819         return new ConstExpr(symbol->constValue, pos);
7820     } else
7821         return this;
7822 }
7823 
EstimateCost() const7824 int SymbolExpr::EstimateCost() const {
7825     // Be optimistic and assume it's in a register or can be used as a
7826     // memory operand..
7827     return 0;
7828 }
7829 
Print() const7830 void SymbolExpr::Print() const {
7831     if (symbol == NULL || GetType() == NULL)
7832         return;
7833 
7834     printf("[%s] sym: (%s)", GetType()->GetString().c_str(), symbol->name.c_str());
7835     pos.Print();
7836 }
7837 
7838 ///////////////////////////////////////////////////////////////////////////
7839 // FunctionSymbolExpr
7840 
FunctionSymbolExpr(const char * n,const std::vector<Symbol * > & candidates,SourcePos p)7841 FunctionSymbolExpr::FunctionSymbolExpr(const char *n, const std::vector<Symbol *> &candidates, SourcePos p)
7842     : Expr(p, FunctionSymbolExprID) {
7843     name = n;
7844     candidateFunctions = candidates;
7845     matchingFunc = (candidates.size() == 1) ? candidates[0] : NULL;
7846     triedToResolve = false;
7847 }
7848 
GetType() const7849 const Type *FunctionSymbolExpr::GetType() const {
7850     if (triedToResolve == false && matchingFunc == NULL) {
7851         Error(pos, "Ambiguous use of overloaded function \"%s\".", name.c_str());
7852         return NULL;
7853     }
7854 
7855     return matchingFunc ? matchingFunc->type : NULL;
7856 }
7857 
GetValue(FunctionEmitContext * ctx) const7858 llvm::Value *FunctionSymbolExpr::GetValue(FunctionEmitContext *ctx) const {
7859     return matchingFunc ? matchingFunc->function : NULL;
7860 }
7861 
GetBaseSymbol() const7862 Symbol *FunctionSymbolExpr::GetBaseSymbol() const { return matchingFunc; }
7863 
TypeCheck()7864 Expr *FunctionSymbolExpr::TypeCheck() { return this; }
7865 
Optimize()7866 Expr *FunctionSymbolExpr::Optimize() { return this; }
7867 
EstimateCost() const7868 int FunctionSymbolExpr::EstimateCost() const { return 0; }
7869 
Print() const7870 void FunctionSymbolExpr::Print() const {
7871     if (!matchingFunc || !GetType())
7872         return;
7873 
7874     printf("[%s] fun sym (%s)", GetType()->GetString().c_str(), matchingFunc->name.c_str());
7875     pos.Print();
7876 }
7877 
GetConstant(const Type * type) const7878 std::pair<llvm::Constant *, bool> FunctionSymbolExpr::GetConstant(const Type *type) const {
7879     if (matchingFunc == NULL || matchingFunc->function == NULL)
7880         return std::pair<llvm::Constant *, bool>(NULL, false);
7881 
7882     const FunctionType *ft = CastType<FunctionType>(type);
7883     if (ft == NULL)
7884         return std::pair<llvm::Constant *, bool>(NULL, false);
7885 
7886     if (Type::Equal(type, matchingFunc->type) == false) {
7887         Error(pos,
7888               "Type of function symbol \"%s\" doesn't match expected type "
7889               "\"%s\".",
7890               matchingFunc->type->GetString().c_str(), type->GetString().c_str());
7891         return std::pair<llvm::Constant *, bool>(NULL, false);
7892     }
7893 
7894     return std::pair<llvm::Constant *, bool>(matchingFunc->function, false);
7895 }
7896 
lGetOverloadCandidateMessage(const std::vector<Symbol * > & funcs,const std::vector<const Type * > & argTypes,const std::vector<bool> * argCouldBeNULL)7897 static std::string lGetOverloadCandidateMessage(const std::vector<Symbol *> &funcs,
7898                                                 const std::vector<const Type *> &argTypes,
7899                                                 const std::vector<bool> *argCouldBeNULL) {
7900     std::string message = "Passed types: (";
7901     for (unsigned int i = 0; i < argTypes.size(); ++i) {
7902         if (argTypes[i] != NULL)
7903             message += argTypes[i]->GetString();
7904         else
7905             message += "(unknown type)";
7906         message += (i < argTypes.size() - 1) ? ", " : ")\n";
7907     }
7908 
7909     for (unsigned int i = 0; i < funcs.size(); ++i) {
7910         const FunctionType *ft = CastType<FunctionType>(funcs[i]->type);
7911         Assert(ft != NULL);
7912         message += "Candidate: ";
7913         message += ft->GetString();
7914         if (i < funcs.size() - 1)
7915             message += "\n";
7916     }
7917     return message;
7918 }
7919 
7920 /** Helper function used for function overload resolution: returns true if
7921     converting the argument to the call type only requires a type
7922     conversion that won't lose information.  Otherwise return false.
7923   */
lIsMatchWithTypeWidening(const Type * callType,const Type * funcArgType)7924 static bool lIsMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
7925     const AtomicType *callAt = CastType<AtomicType>(callType);
7926     const AtomicType *funcAt = CastType<AtomicType>(funcArgType);
7927     if (callAt == NULL || funcAt == NULL)
7928         return false;
7929 
7930     if (callAt->IsUniformType() != funcAt->IsUniformType())
7931         return false;
7932 
7933     switch (callAt->basicType) {
7934     case AtomicType::TYPE_BOOL:
7935         return true;
7936     case AtomicType::TYPE_INT8:
7937     case AtomicType::TYPE_UINT8:
7938         return (funcAt->basicType != AtomicType::TYPE_BOOL);
7939     case AtomicType::TYPE_INT16:
7940     case AtomicType::TYPE_UINT16:
7941         return (funcAt->basicType != AtomicType::TYPE_BOOL && funcAt->basicType != AtomicType::TYPE_INT8 &&
7942                 funcAt->basicType != AtomicType::TYPE_UINT8);
7943     case AtomicType::TYPE_INT32:
7944     case AtomicType::TYPE_UINT32:
7945         return (funcAt->basicType == AtomicType::TYPE_INT32 || funcAt->basicType == AtomicType::TYPE_UINT32 ||
7946                 funcAt->basicType == AtomicType::TYPE_INT64 || funcAt->basicType == AtomicType::TYPE_UINT64);
7947     case AtomicType::TYPE_FLOAT:
7948         return (funcAt->basicType == AtomicType::TYPE_DOUBLE);
7949     case AtomicType::TYPE_INT64:
7950     case AtomicType::TYPE_UINT64:
7951         return (funcAt->basicType == AtomicType::TYPE_INT64 || funcAt->basicType == AtomicType::TYPE_UINT64);
7952     case AtomicType::TYPE_DOUBLE:
7953         return false;
7954     default:
7955         FATAL("Unhandled atomic type");
7956         return false;
7957     }
7958 }
7959 
7960 /* Returns the set of function overloads that are potential matches, given
7961    argCount values being passed as arguments to the function call.
7962  */
getCandidateFunctions(int argCount) const7963 std::vector<Symbol *> FunctionSymbolExpr::getCandidateFunctions(int argCount) const {
7964     std::vector<Symbol *> ret;
7965     for (int i = 0; i < (int)candidateFunctions.size(); ++i) {
7966         const FunctionType *ft = CastType<FunctionType>(candidateFunctions[i]->type);
7967         AssertPos(pos, ft != NULL);
7968 
7969         // There's no way to match if the caller is passing more arguments
7970         // than this function instance takes.
7971         if (argCount > ft->GetNumParameters())
7972             continue;
7973 
7974         // Not enough arguments, and no default argument value to save us
7975         if (argCount < ft->GetNumParameters() && ft->GetParameterDefault(argCount) == NULL)
7976             continue;
7977 
7978         // Success
7979         ret.push_back(candidateFunctions[i]);
7980     }
7981     return ret;
7982 }
7983 
lArgIsPointerType(const Type * type)7984 static bool lArgIsPointerType(const Type *type) {
7985     if (CastType<PointerType>(type) != NULL)
7986         return true;
7987 
7988     const ReferenceType *rt = CastType<ReferenceType>(type);
7989     if (rt == NULL)
7990         return false;
7991 
7992     const Type *t = rt->GetReferenceTarget();
7993     return (CastType<PointerType>(t) != NULL);
7994 }
7995 
/** This function computes the value of a cost function that represents the
    cost of calling a function of the given type with arguments of the
    given types.  If it's not possible to call the function, regardless of
    any type conversions applied, a cost of -1 is returned.

    @param ftype          Type of the candidate function.
    @param argTypes       Types of the arguments passed at the call site.
    @param argCouldBeNULL If non-NULL, element i tells whether argument i
                          could be a NULL pointer.
    @param argIsConstant  If non-NULL, element i tells whether argument i
                          is a compile-time constant.
    @param cost           Output array with at least argTypes.size()
                          elements; element i receives the cost of
                          argument i.  When -1 is returned, elements after
                          the failing argument have not been written.
    @return The sum of the per-argument costs, or -1 if the call is not
            possible.
 */
int FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype, const std::vector<const Type *> &argTypes,
                                            const std::vector<bool> *argCouldBeNULL,
                                            const std::vector<bool> *argIsConstant, int *cost) {
    int costSum = 0;

    // In computing the cost function, we only worry about the actual
    // argument types--using function default parameter values is free for
    // the purposes here...
    for (int i = 0; i < (int)argTypes.size(); ++i) {
        cost[i] = 0;
        // The cost imposed by this argument will be a multiple of
        // costScale, which has a value set so that for each of the cost
        // buckets, even if all of the function arguments undergo the next
        // lower-cost conversion, the sum of their costs will be less than
        // a single instance of the next higher-cost conversion.
        int costScale = argTypes.size() + 1;

        const Type *fargType = ftype->GetParameterType(i);
        const Type *callType = argTypes[i];

        if (Type::Equal(callType, fargType))
            // Perfect match: no cost
            // Step "1" from documentation
            cost[i] += 0;
        else if (argCouldBeNULL && (*argCouldBeNULL)[i] && lArgIsPointerType(fargType))
            // Passing NULL to a pointer-typed parameter is also a no-cost operation
            // Step "1" from documentation
            cost[i] += 0;
        else {
            // If the argument is a compile-time constant, we'd like to
            // count the cost of various conversions as much lower than the
            // cost if it wasn't--so scale up the cost when this isn't the
            // case..
            if (argIsConstant == NULL || (*argIsConstant)[i] == false)
                costScale *= 512;

            if (CastType<ReferenceType>(fargType)) {
                // Here we completely handle the case where fargType is a reference.
                if (callType->IsConstType() && !fargType->IsConstType()) {
                    // It is forbidden to pass a const object to a non-const reference (cvf -> vfr)
                    return -1;
                }
                if (!callType->IsConstType() && fargType->IsConstType()) {
                    // It is possible to pass (vf -> cvfr)
                    // but it is worse than (vf -> vfr) or (cvf -> cvfr)
                    // Step "3" from documentation
                    cost[i] += 2 * costScale;
                }
                if (!Type::Equal(callType->GetReferenceTarget()->GetAsNonConstType(),
                                 fargType->GetReferenceTarget()->GetAsNonConstType())) {
                    // Types under references must be equal completely.
                    // vd -> vfr or vd -> cvfr are forbidden. (Although clang allows vd -> cvfr case.)
                    return -1;
                }
                // penalty for equal types under reference (vf -> vfr is worse than vf -> vf)
                // Step "2" from documentation
                cost[i] += 2 * costScale;
                continue;
            }
            const Type *callTypeNP = callType;
            if (CastType<ReferenceType>(callType)) {
                callTypeNP = callType->GetReferenceTarget();
                // we can treat vfr as vf for callType with some penalty
                // Step "5" from documentation
                cost[i] += 2 * costScale;
            }

            // Now that references have been dealt with, we can normalize to
            // non-const types: we're passing by value anyway, so const
            // doesn't matter.
            const Type *callTypeNC = callTypeNP->GetAsNonConstType();
            const Type *fargTypeNC = fargType->GetAsNonConstType();

            // Now we forget about constants and references!
            if (Type::EqualIgnoringConst(callTypeNP, fargType)) {
                // The best case: vf -> vf.
                // Step "4" from documentation
                cost[i] += 1 * costScale;
                continue;
            }
            if (lIsMatchWithTypeWidening(callTypeNC, fargTypeNC)) {
                // A little bit worse case: vf -> vd.
                // Step "6" from documentation
                cost[i] += 8 * costScale;
                continue;
            }
            if (fargType->IsVaryingType() && callType->IsUniformType()) {
                // Here we deal with broadcasting uniform to varying.
                // callType - varying and fargType - uniform is forbidden.
                if (Type::Equal(callTypeNC->GetAsVaryingType(), fargTypeNC)) {
                    // uf -> vf is better than uf -> ui or uf -> ud
                    // Step "7" from documentation
                    cost[i] += 16 * costScale;
                    continue;
                }
                if (lIsMatchWithTypeWidening(callTypeNC->GetAsVaryingType(), fargTypeNC)) {
                    // uf -> vd is better than uf -> vi (128 < 128 + 64)
                    // but worse than uf -> ui (128 > 64)
                    // Step "9" from documentation
                    cost[i] += 128 * costScale;
                    continue;
                }
                // 128 + 64 is the max. uf -> vi is the worst case.
                // Step "10" from documentation
                // Only the 128 part is added here; the remaining 64 is
                // added by the CanConvertTypes() check below.
                cost[i] += 128 * costScale;
            }
            if (CanConvertTypes(callTypeNC, fargTypeNC))
                // two cases: the worst is 128 + 64: uf -> vi and
                // the only 64: (64 < 128) uf -> ui worse than uf -> vd
                // Step "8" from documentation
                cost[i] += 64 * costScale;
            else
                // Failure--no type conversion possible...
                return -1;
        }
    }

    // The total cost is the sum of the per-argument costs.
    for (int i = 0; i < (int)argTypes.size(); ++i) {
        costSum = costSum + cost[i];
    }
    return costSum;
}
8122 
ResolveOverloads(SourcePos argPos,const std::vector<const Type * > & argTypes,const std::vector<bool> * argCouldBeNULL,const std::vector<bool> * argIsConstant)8123 bool FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, const std::vector<const Type *> &argTypes,
8124                                           const std::vector<bool> *argCouldBeNULL,
8125                                           const std::vector<bool> *argIsConstant) {
8126     const char *funName = candidateFunctions.front()->name.c_str();
8127     if (triedToResolve == true) {
8128         return true;
8129     }
8130 
8131     triedToResolve = true;
8132 
8133     // Functions with names that start with "__" should only be various
8134     // builtins.  For those, we'll demand an exact match, since we'll
8135     // expect whichever function in stdlib.ispc is calling out to one of
8136     // those to be matching the argument types exactly; this is to be a bit
8137     // extra safe to be sure that the expected builtin is in fact being
8138     // called.
8139     bool exactMatchOnly = (name.substr(0, 2) == "__");
8140 
8141     // First, find the subset of overload candidates that take the same
8142     // number of arguments as have parameters (including functions that
8143     // take more arguments but have defaults starting no later than after
8144     // our last parameter).
8145     std::vector<Symbol *> actualCandidates = getCandidateFunctions(argTypes.size());
8146 
8147     int bestMatchCost = 1 << 30;
8148     std::vector<Symbol *> matches;
8149     std::vector<int> candidateCosts;
8150     std::vector<int *> candidateExpandCosts;
8151 
8152     if (actualCandidates.size() == 0)
8153         goto failure;
8154 
8155     // Compute the cost for calling each of the candidate functions
8156     for (int i = 0; i < (int)actualCandidates.size(); ++i) {
8157         const FunctionType *ft = CastType<FunctionType>(actualCandidates[i]->type);
8158         AssertPos(pos, ft != NULL);
8159         int *cost = new int[argTypes.size()];
8160         candidateCosts.push_back(computeOverloadCost(ft, argTypes, argCouldBeNULL, argIsConstant, cost));
8161         candidateExpandCosts.push_back(cost);
8162     }
8163 
8164     // Find the best cost, and then the candidate or candidates that have
8165     // that cost.
8166     for (int i = 0; i < (int)candidateCosts.size(); ++i) {
8167         if (candidateCosts[i] != -1 && candidateCosts[i] < bestMatchCost)
8168             bestMatchCost = candidateCosts[i];
8169     }
8170     // None of the candidates matched
8171     if (bestMatchCost == (1 << 30))
8172         goto failure;
8173     for (int i = 0; i < (int)candidateCosts.size(); ++i) {
8174         if (candidateCosts[i] == bestMatchCost) {
8175             for (int j = 0; j < (int)candidateCosts.size(); ++j) {
8176                 for (int k = 0; k < argTypes.size(); k++) {
8177                     if (candidateCosts[j] != -1 && candidateExpandCosts[j][k] < candidateExpandCosts[i][k]) {
8178                         std::vector<Symbol *> temp;
8179                         temp.push_back(actualCandidates[i]);
8180                         temp.push_back(actualCandidates[j]);
8181                         std::string candidateMessage = lGetOverloadCandidateMessage(temp, argTypes, argCouldBeNULL);
8182                         Warning(pos,
8183                                 "call to \"%s\" is ambiguous. "
8184                                 "This warning will be turned into error in the next ispc release.\n"
8185                                 "Please add explicit cast to arguments to have unambiguous match."
8186                                 "\n%s",
8187                                 funName, candidateMessage.c_str());
8188                     }
8189                 }
8190             }
8191             matches.push_back(actualCandidates[i]);
8192         }
8193     }
8194     for (int i = 0; i < (int)candidateExpandCosts.size(); ++i) {
8195         delete[] candidateExpandCosts[i];
8196     }
8197 
8198     if (matches.size() == 1) {
8199         // Only one match: success
8200         matchingFunc = matches[0];
8201         return true;
8202     } else if (matches.size() > 1) {
8203         // Multiple matches: ambiguous
8204         std::string candidateMessage = lGetOverloadCandidateMessage(matches, argTypes, argCouldBeNULL);
8205         Error(pos,
8206               "Multiple overloaded functions matched call to function "
8207               "\"%s\"%s.\n%s",
8208               funName, exactMatchOnly ? " only considering exact matches" : "", candidateMessage.c_str());
8209         return false;
8210     } else {
8211         // No matches at all
8212     failure:
8213         std::string candidateMessage = lGetOverloadCandidateMessage(matches, argTypes, argCouldBeNULL);
8214         Error(pos,
8215               "Unable to find any matching overload for call to function "
8216               "\"%s\"%s.\n%s",
8217               funName, exactMatchOnly ? " only considering exact matches" : "", candidateMessage.c_str());
8218         return false;
8219     }
8220 }
8221 
GetMatchingFunction()8222 Symbol *FunctionSymbolExpr::GetMatchingFunction() { return matchingFunc; }
8223 
8224 ///////////////////////////////////////////////////////////////////////////
8225 // SyncExpr
8226 
GetType() const8227 const Type *SyncExpr::GetType() const { return AtomicType::Void; }
8228 
GetValue(FunctionEmitContext * ctx) const8229 llvm::Value *SyncExpr::GetValue(FunctionEmitContext *ctx) const {
8230     ctx->SetDebugPos(pos);
8231     ctx->SyncInst();
8232     return NULL;
8233 }
8234 
EstimateCost() const8235 int SyncExpr::EstimateCost() const { return COST_SYNC; }
8236 
Print() const8237 void SyncExpr::Print() const {
8238     printf("sync");
8239     pos.Print();
8240 }
8241 
TypeCheck()8242 Expr *SyncExpr::TypeCheck() { return this; }
8243 
Optimize()8244 Expr *SyncExpr::Optimize() { return this; }
8245 
8246 ///////////////////////////////////////////////////////////////////////////
8247 // NullPointerExpr
8248 
GetValue(FunctionEmitContext * ctx) const8249 llvm::Value *NullPointerExpr::GetValue(FunctionEmitContext *ctx) const {
8250     return llvm::ConstantPointerNull::get(LLVMTypes::VoidPointerType);
8251 }
8252 
GetType() const8253 const Type *NullPointerExpr::GetType() const { return PointerType::Void; }
8254 
TypeCheck()8255 Expr *NullPointerExpr::TypeCheck() { return this; }
8256 
Optimize()8257 Expr *NullPointerExpr::Optimize() { return this; }
8258 
GetConstant(const Type * type) const8259 std::pair<llvm::Constant *, bool> NullPointerExpr::GetConstant(const Type *type) const {
8260     const PointerType *pt = CastType<PointerType>(type);
8261     if (pt == NULL)
8262         return std::pair<llvm::Constant *, bool>(NULL, false);
8263 
8264     llvm::Type *llvmType = type->LLVMType(g->ctx);
8265     if (llvmType == NULL) {
8266         AssertPos(pos, m->errorCount > 0);
8267         return std::pair<llvm::Constant *, bool>(NULL, false);
8268     }
8269 
8270     return std::pair<llvm::Constant *, bool>(llvm::Constant::getNullValue(llvmType), false);
8271 }
8272 
Print() const8273 void NullPointerExpr::Print() const {
8274     printf("NULL");
8275     pos.Print();
8276 }
8277 
EstimateCost() const8278 int NullPointerExpr::EstimateCost() const { return 0; }
8279 
8280 ///////////////////////////////////////////////////////////////////////////
8281 // NewExpr
8282 
NewExpr(int typeQual,const Type * t,Expr * init,Expr * count,SourcePos tqPos,SourcePos p)8283 NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count, SourcePos tqPos, SourcePos p)
8284     : Expr(p, NewExprID) {
8285     allocType = t;
8286 
8287     initExpr = init;
8288     countExpr = count;
8289 
8290     /* (The below cases actually should be impossible, since the parser
8291        doesn't allow more than a single type qualifier before a "new".) */
8292     if ((typeQual & ~(TYPEQUAL_UNIFORM | TYPEQUAL_VARYING)) != 0) {
8293         Error(tqPos, "Illegal type qualifiers in \"new\" expression (only "
8294                      "\"uniform\" and \"varying\" are allowed.");
8295         isVarying = false;
8296     } else if ((typeQual & TYPEQUAL_UNIFORM) != 0 && (typeQual & TYPEQUAL_VARYING) != 0) {
8297         Error(tqPos, "Illegal to provide both \"uniform\" and \"varying\" "
8298                      "qualifiers to \"new\" expression.");
8299         isVarying = false;
8300     } else
8301         // If no type qualifier is given before the 'new', treat it as a
8302         // varying new.
8303         isVarying = (typeQual == 0) || (typeQual & TYPEQUAL_VARYING);
8304 
8305     if (allocType != NULL)
8306         allocType = allocType->ResolveUnboundVariability(Variability::Uniform);
8307 }
8308 
GetValue(FunctionEmitContext * ctx) const8309 llvm::Value *NewExpr::GetValue(FunctionEmitContext *ctx) const {
8310     bool do32Bit = (g->target->is32Bit() || g->opt.force32BitAddressing);
8311 
8312     // Determine how many elements we need to allocate.  Note that this
8313     // will be a varying value if this is a varying new.
8314     llvm::Value *countValue;
8315     if (countExpr != NULL) {
8316         countValue = countExpr->GetValue(ctx);
8317         if (countValue == NULL) {
8318             AssertPos(pos, m->errorCount > 0);
8319             return NULL;
8320         }
8321     } else {
8322         if (isVarying) {
8323             if (do32Bit)
8324                 countValue = LLVMInt32Vector(1);
8325             else
8326                 countValue = LLVMInt64Vector(1);
8327         } else {
8328             if (do32Bit)
8329                 countValue = LLVMInt32(1);
8330             else
8331                 countValue = LLVMInt64(1);
8332         }
8333     }
8334 
8335     // Compute the total amount of memory to allocate, allocSize, as the
8336     // product of the number of elements to allocate and the size of a
8337     // single element.
8338     llvm::Value *eltSize = g->target->SizeOf(allocType->LLVMType(g->ctx), ctx->GetCurrentBasicBlock());
8339     if (isVarying)
8340         eltSize = ctx->SmearUniform(eltSize, "smear_size");
8341     llvm::Value *allocSize = ctx->BinaryOperator(llvm::Instruction::Mul, countValue, eltSize, "alloc_size");
8342 
8343     // Determine which allocation builtin function to call: uniform or
8344     // varying, and taking 32-bit or 64-bit allocation counts.
8345     llvm::Function *func;
8346     if (isVarying) {
8347         if (g->target->is32Bit()) {
8348             func = m->module->getFunction("__new_varying32_32rt");
8349         } else if (g->opt.force32BitAddressing) {
8350             func = m->module->getFunction("__new_varying32_64rt");
8351         } else {
8352             func = m->module->getFunction("__new_varying64_64rt");
8353         }
8354     } else {
8355         // FIXME: __new_uniform_32rt should take i32
8356         if (allocSize->getType() != LLVMTypes::Int64Type)
8357             allocSize = ctx->SExtInst(allocSize, LLVMTypes::Int64Type, "alloc_size64");
8358         if (g->target->is32Bit()) {
8359             func = m->module->getFunction("__new_uniform_32rt");
8360         } else {
8361             func = m->module->getFunction("__new_uniform_64rt");
8362         }
8363     }
8364     AssertPos(pos, func != NULL);
8365 
8366     // Make the call for the the actual allocation.
8367     llvm::Value *ptrValue = ctx->CallInst(func, NULL, allocSize, "new");
8368 
8369     // Now handle initializers and returning the right type for the result.
8370     const Type *retType = GetType();
8371     if (retType == NULL)
8372         return NULL;
8373     if (isVarying) {
8374         if (g->target->is32Bit())
8375             // Convert i64 vector values to i32 if we are compiling to a
8376             // 32-bit target.
8377             ptrValue = ctx->TruncInst(ptrValue, LLVMTypes::VoidPointerVectorType, "ptr_to_32bit");
8378 
8379         if (initExpr != NULL) {
8380             // If we have an initializer expression, emit code that checks
8381             // to see if each lane is active and if so, runs the code to do
8382             // the initialization.  Note that we're we're taking advantage
8383             // of the fact that the __new_varying*() functions are
8384             // implemented to return NULL for program instances that aren't
8385             // executing; more generally, we should be using the current
8386             // execution mask for this...
8387             for (int i = 0; i < g->target->getVectorWidth(); ++i) {
8388                 llvm::BasicBlock *bbInit = ctx->CreateBasicBlock("init_ptr");
8389                 llvm::BasicBlock *bbSkip = ctx->CreateBasicBlock("skip_init");
8390                 llvm::Value *p = ctx->ExtractInst(ptrValue, i);
8391                 llvm::Value *nullValue = g->target->is32Bit() ? LLVMInt32(0) : LLVMInt64(0);
8392                 // Is the pointer for the current lane non-zero?
8393                 llvm::Value *nonNull =
8394                     ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, p, nullValue, "non_null");
8395                 ctx->BranchInst(bbInit, bbSkip, nonNull);
8396 
8397                 // Initialize the memory pointed to by the pointer for the
8398                 // current lane.
8399                 ctx->SetCurrentBasicBlock(bbInit);
8400                 llvm::Type *ptrType = retType->GetAsUniformType()->LLVMType(g->ctx);
8401                 llvm::Value *ptr = ctx->IntToPtrInst(p, ptrType);
8402                 InitSymbol(ptr, allocType, initExpr, ctx, pos);
8403                 ctx->BranchInst(bbSkip);
8404 
8405                 ctx->SetCurrentBasicBlock(bbSkip);
8406             }
8407         }
8408 
8409         return ptrValue;
8410     } else {
8411         // For uniform news, we just need to cast the void * to be a
8412         // pointer of the return type and to run the code for initializers,
8413         // if present.
8414         llvm::Type *ptrType = retType->LLVMType(g->ctx);
8415         ptrValue = ctx->BitCastInst(ptrValue, ptrType, llvm::Twine(ptrValue->getName()) + "_cast_ptr");
8416 
8417         if (initExpr != NULL)
8418             InitSymbol(ptrValue, allocType, initExpr, ctx, pos);
8419 
8420         return ptrValue;
8421     }
8422 }
8423 
GetType() const8424 const Type *NewExpr::GetType() const {
8425     if (allocType == NULL)
8426         return NULL;
8427 
8428     return isVarying ? PointerType::GetVarying(allocType) : PointerType::GetUniform(allocType);
8429 }
8430 
TypeCheck()8431 Expr *NewExpr::TypeCheck() {
8432     // It's illegal to call new with an undefined struct type
8433     if (allocType == NULL) {
8434         AssertPos(pos, m->errorCount > 0);
8435         return NULL;
8436     }
8437 
8438     if (g->target->isGenXTarget()) {
8439         Error(pos, "\"new\" is not supported for genx-* targets yet.");
8440         return NULL;
8441     }
8442 
8443     if (CastType<UndefinedStructType>(allocType) != NULL) {
8444         Error(pos,
8445               "Can't dynamically allocate storage for declared "
8446               "but not defined type \"%s\".",
8447               allocType->GetString().c_str());
8448         return NULL;
8449     }
8450     const StructType *st = CastType<StructType>(allocType);
8451     if (st != NULL && !st->IsDefined()) {
8452         Error(pos,
8453               "Can't dynamically allocate storage for declared "
8454               "type \"%s\" containing undefined member type.",
8455               allocType->GetString().c_str());
8456         return NULL;
8457     }
8458 
8459     // Otherwise we only need to make sure that if we have an expression
8460     // giving a number of elements to allocate that it can be converted to
8461     // an integer of the appropriate variability.
8462     if (countExpr == NULL)
8463         return this;
8464 
8465     const Type *countType;
8466     if ((countType = countExpr->GetType()) == NULL)
8467         return NULL;
8468 
8469     if (isVarying == false && countType->IsVaryingType()) {
8470         Error(pos, "Illegal to provide \"varying\" allocation count with "
8471                    "\"uniform new\" expression.");
8472         return NULL;
8473     }
8474 
8475     // Figure out the type that the allocation count should be
8476     const Type *t =
8477         (g->target->is32Bit() || g->opt.force32BitAddressing) ? AtomicType::UniformUInt32 : AtomicType::UniformUInt64;
8478     if (isVarying)
8479         t = t->GetAsVaryingType();
8480 
8481     countExpr = TypeConvertExpr(countExpr, t, "item count");
8482     if (countExpr == NULL)
8483         return NULL;
8484 
8485     return this;
8486 }
8487 
Optimize()8488 Expr *NewExpr::Optimize() { return this; }
8489 
Print() const8490 void NewExpr::Print() const { printf("new (%s)", allocType ? allocType->GetString().c_str() : "NULL"); }
8491 
EstimateCost() const8492 int NewExpr::EstimateCost() const { return COST_NEW; }
8493