1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ConstantUnion: Constant folding helper class.
7 
8 #include "compiler/translator/ConstantUnion.h"
9 
10 #include "common/mathutil.h"
11 #include "compiler/translator/Diagnostics.h"
12 
13 namespace sh
14 {
15 
16 namespace
17 {
18 
CheckedSum(float lhs,float rhs,TDiagnostics * diag,const TSourceLoc & line)19 float CheckedSum(float lhs, float rhs, TDiagnostics *diag, const TSourceLoc &line)
20 {
21     float result = lhs + rhs;
22     if (gl::isNaN(result) && !gl::isNaN(lhs) && !gl::isNaN(rhs))
23     {
24         diag->warning(line, "Constant folded undefined addition generated NaN", "+");
25     }
26     else if (gl::isInf(result) && !gl::isInf(lhs) && !gl::isInf(rhs))
27     {
28         diag->warning(line, "Constant folded addition overflowed to infinity", "+");
29     }
30     return result;
31 }
32 
CheckedDiff(float lhs,float rhs,TDiagnostics * diag,const TSourceLoc & line)33 float CheckedDiff(float lhs, float rhs, TDiagnostics *diag, const TSourceLoc &line)
34 {
35     float result = lhs - rhs;
36     if (gl::isNaN(result) && !gl::isNaN(lhs) && !gl::isNaN(rhs))
37     {
38         diag->warning(line, "Constant folded undefined subtraction generated NaN", "-");
39     }
40     else if (gl::isInf(result) && !gl::isInf(lhs) && !gl::isInf(rhs))
41     {
42         diag->warning(line, "Constant folded subtraction overflowed to infinity", "-");
43     }
44     return result;
45 }
46 
CheckedMul(float lhs,float rhs,TDiagnostics * diag,const TSourceLoc & line)47 float CheckedMul(float lhs, float rhs, TDiagnostics *diag, const TSourceLoc &line)
48 {
49     float result = lhs * rhs;
50     if (gl::isNaN(result) && !gl::isNaN(lhs) && !gl::isNaN(rhs))
51     {
52         diag->warning(line, "Constant folded undefined multiplication generated NaN", "*");
53     }
54     else if (gl::isInf(result) && !gl::isInf(lhs) && !gl::isInf(rhs))
55     {
56         diag->warning(line, "Constant folded multiplication overflowed to infinity", "*");
57     }
58     return result;
59 }
60 
IsValidShiftOffset(const TConstantUnion & rhs)61 bool IsValidShiftOffset(const TConstantUnion &rhs)
62 {
63     return (rhs.getType() == EbtInt && (rhs.getIConst() >= 0 && rhs.getIConst() <= 31)) ||
64            (rhs.getType() == EbtUInt && rhs.getUConst() <= 31u);
65 }
66 
67 }  // anonymous namespace
68 
TConstantUnion()69 TConstantUnion::TConstantUnion()
70 {
71     iConst = 0;
72     type   = EbtVoid;
73 }
74 
getIConst() const75 int TConstantUnion::getIConst() const
76 {
77     ASSERT(type == EbtInt);
78     return iConst;
79 }
80 
getUConst() const81 unsigned int TConstantUnion::getUConst() const
82 {
83     ASSERT(type == EbtUInt);
84     return uConst;
85 }
86 
getFConst() const87 float TConstantUnion::getFConst() const
88 {
89     ASSERT(type == EbtFloat);
90     return fConst;
91 }
92 
getBConst() const93 bool TConstantUnion::getBConst() const
94 {
95     ASSERT(type == EbtBool);
96     return bConst;
97 }
98 
getYuvCscStandardEXTConst() const99 TYuvCscStandardEXT TConstantUnion::getYuvCscStandardEXTConst() const
100 {
101     ASSERT(type == EbtYuvCscStandardEXT);
102     return yuvCscStandardEXTConst;
103 }
104 
cast(TBasicType newType,const TConstantUnion & constant)105 bool TConstantUnion::cast(TBasicType newType, const TConstantUnion &constant)
106 {
107     switch (newType)
108     {
109         case EbtFloat:
110             switch (constant.type)
111             {
112                 case EbtInt:
113                     setFConst(static_cast<float>(constant.getIConst()));
114                     break;
115                 case EbtUInt:
116                     setFConst(static_cast<float>(constant.getUConst()));
117                     break;
118                 case EbtBool:
119                     setFConst(static_cast<float>(constant.getBConst()));
120                     break;
121                 case EbtFloat:
122                     setFConst(static_cast<float>(constant.getFConst()));
123                     break;
124                 default:
125                     return false;
126             }
127             break;
128         case EbtInt:
129             switch (constant.type)
130             {
131                 case EbtInt:
132                     setIConst(static_cast<int>(constant.getIConst()));
133                     break;
134                 case EbtUInt:
135                     setIConst(static_cast<int>(constant.getUConst()));
136                     break;
137                 case EbtBool:
138                     setIConst(static_cast<int>(constant.getBConst()));
139                     break;
140                 case EbtFloat:
141                     setIConst(static_cast<int>(constant.getFConst()));
142                     break;
143                 default:
144                     return false;
145             }
146             break;
147         case EbtUInt:
148             switch (constant.type)
149             {
150                 case EbtInt:
151                     setUConst(static_cast<unsigned int>(constant.getIConst()));
152                     break;
153                 case EbtUInt:
154                     setUConst(static_cast<unsigned int>(constant.getUConst()));
155                     break;
156                 case EbtBool:
157                     setUConst(static_cast<unsigned int>(constant.getBConst()));
158                     break;
159                 case EbtFloat:
160                     setUConst(static_cast<unsigned int>(constant.getFConst()));
161                     break;
162                 default:
163                     return false;
164             }
165             break;
166         case EbtBool:
167             switch (constant.type)
168             {
169                 case EbtInt:
170                     setBConst(constant.getIConst() != 0);
171                     break;
172                 case EbtUInt:
173                     setBConst(constant.getUConst() != 0);
174                     break;
175                 case EbtBool:
176                     setBConst(constant.getBConst());
177                     break;
178                 case EbtFloat:
179                     setBConst(constant.getFConst() != 0.0f);
180                     break;
181                 default:
182                     return false;
183             }
184             break;
185         case EbtStruct:  // Struct fields don't get cast
186             switch (constant.type)
187             {
188                 case EbtInt:
189                     setIConst(constant.getIConst());
190                     break;
191                 case EbtUInt:
192                     setUConst(constant.getUConst());
193                     break;
194                 case EbtBool:
195                     setBConst(constant.getBConst());
196                     break;
197                 case EbtFloat:
198                     setFConst(constant.getFConst());
199                     break;
200                 default:
201                     return false;
202             }
203             break;
204         default:
205             return false;
206     }
207 
208     return true;
209 }
210 
operator ==(const int i) const211 bool TConstantUnion::operator==(const int i) const
212 {
213     return i == iConst;
214 }
215 
operator ==(const unsigned int u) const216 bool TConstantUnion::operator==(const unsigned int u) const
217 {
218     return u == uConst;
219 }
220 
operator ==(const float f) const221 bool TConstantUnion::operator==(const float f) const
222 {
223     return f == fConst;
224 }
225 
operator ==(const bool b) const226 bool TConstantUnion::operator==(const bool b) const
227 {
228     return b == bConst;
229 }
230 
operator ==(const TYuvCscStandardEXT s) const231 bool TConstantUnion::operator==(const TYuvCscStandardEXT s) const
232 {
233     return s == yuvCscStandardEXTConst;
234 }
235 
operator ==(const TConstantUnion & constant) const236 bool TConstantUnion::operator==(const TConstantUnion &constant) const
237 {
238     if (constant.type != type)
239         return false;
240 
241     switch (type)
242     {
243         case EbtInt:
244             return constant.iConst == iConst;
245         case EbtUInt:
246             return constant.uConst == uConst;
247         case EbtFloat:
248             return constant.fConst == fConst;
249         case EbtBool:
250             return constant.bConst == bConst;
251         case EbtYuvCscStandardEXT:
252             return constant.yuvCscStandardEXTConst == yuvCscStandardEXTConst;
253         default:
254             return false;
255     }
256 }
257 
operator !=(const int i) const258 bool TConstantUnion::operator!=(const int i) const
259 {
260     return !operator==(i);
261 }
262 
operator !=(const unsigned int u) const263 bool TConstantUnion::operator!=(const unsigned int u) const
264 {
265     return !operator==(u);
266 }
267 
operator !=(const float f) const268 bool TConstantUnion::operator!=(const float f) const
269 {
270     return !operator==(f);
271 }
272 
operator !=(const bool b) const273 bool TConstantUnion::operator!=(const bool b) const
274 {
275     return !operator==(b);
276 }
277 
operator !=(const TYuvCscStandardEXT s) const278 bool TConstantUnion::operator!=(const TYuvCscStandardEXT s) const
279 {
280     return !operator==(s);
281 }
282 
operator !=(const TConstantUnion & constant) const283 bool TConstantUnion::operator!=(const TConstantUnion &constant) const
284 {
285     return !operator==(constant);
286 }
287 
operator >(const TConstantUnion & constant) const288 bool TConstantUnion::operator>(const TConstantUnion &constant) const
289 {
290     ASSERT(type == constant.type);
291     switch (type)
292     {
293         case EbtInt:
294             return iConst > constant.iConst;
295         case EbtUInt:
296             return uConst > constant.uConst;
297         case EbtFloat:
298             return fConst > constant.fConst;
299         default:
300             return false;  // Invalid operation, handled at semantic analysis
301     }
302 }
303 
operator <(const TConstantUnion & constant) const304 bool TConstantUnion::operator<(const TConstantUnion &constant) const
305 {
306     ASSERT(type == constant.type);
307     switch (type)
308     {
309         case EbtInt:
310             return iConst < constant.iConst;
311         case EbtUInt:
312             return uConst < constant.uConst;
313         case EbtFloat:
314             return fConst < constant.fConst;
315         default:
316             return false;  // Invalid operation, handled at semantic analysis
317     }
318 }
319 
320 // static
add(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)321 TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
322                                    const TConstantUnion &rhs,
323                                    TDiagnostics *diag,
324                                    const TSourceLoc &line)
325 {
326     TConstantUnion returnValue;
327     ASSERT(lhs.type == rhs.type);
328     switch (lhs.type)
329     {
330         case EbtInt:
331             returnValue.setIConst(gl::WrappingSum<int>(lhs.iConst, rhs.iConst));
332             break;
333         case EbtUInt:
334             returnValue.setUConst(gl::WrappingSum<unsigned int>(lhs.uConst, rhs.uConst));
335             break;
336         case EbtFloat:
337             returnValue.setFConst(CheckedSum(lhs.fConst, rhs.fConst, diag, line));
338             break;
339         default:
340             UNREACHABLE();
341     }
342 
343     return returnValue;
344 }
345 
346 // static
sub(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)347 TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
348                                    const TConstantUnion &rhs,
349                                    TDiagnostics *diag,
350                                    const TSourceLoc &line)
351 {
352     TConstantUnion returnValue;
353     ASSERT(lhs.type == rhs.type);
354     switch (lhs.type)
355     {
356         case EbtInt:
357             returnValue.setIConst(gl::WrappingDiff<int>(lhs.iConst, rhs.iConst));
358             break;
359         case EbtUInt:
360             returnValue.setUConst(gl::WrappingDiff<unsigned int>(lhs.uConst, rhs.uConst));
361             break;
362         case EbtFloat:
363             returnValue.setFConst(CheckedDiff(lhs.fConst, rhs.fConst, diag, line));
364             break;
365         default:
366             UNREACHABLE();
367     }
368 
369     return returnValue;
370 }
371 
372 // static
mul(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)373 TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
374                                    const TConstantUnion &rhs,
375                                    TDiagnostics *diag,
376                                    const TSourceLoc &line)
377 {
378     TConstantUnion returnValue;
379     ASSERT(lhs.type == rhs.type);
380     switch (lhs.type)
381     {
382         case EbtInt:
383             returnValue.setIConst(gl::WrappingMul(lhs.iConst, rhs.iConst));
384             break;
385         case EbtUInt:
386             // Unsigned integer math in C++ is defined to be done in modulo 2^n, so we rely on that
387             // to implement wrapping multiplication.
388             returnValue.setUConst(lhs.uConst * rhs.uConst);
389             break;
390         case EbtFloat:
391             returnValue.setFConst(CheckedMul(lhs.fConst, rhs.fConst, diag, line));
392             break;
393         default:
394             UNREACHABLE();
395     }
396 
397     return returnValue;
398 }
399 
// Folds the integer remainder of this constant divided by |constant|. Both operands must have the
// same type (asserted) and be int or uint; any other type hits UNREACHABLE().
//
// A zero divisor, and INT_MIN % -1, are undefined behavior in C++ and can crash the compiler
// process, so they are guarded here:
//   - a divisor of -1 yields 0, which is the exact remainder for every dividend;
//   - a divisor of 0 yields 0. This operator has no diagnostics sink, so reporting the condition
//     is left to the caller.
TConstantUnion TConstantUnion::operator%(const TConstantUnion &constant) const
{
    TConstantUnion returnValue;
    ASSERT(type == constant.type);
    switch (type)
    {
        case EbtInt:
            if (constant.iConst == 0 || constant.iConst == -1)
            {
                returnValue.setIConst(0);
            }
            else
            {
                returnValue.setIConst(iConst % constant.iConst);
            }
            break;
        case EbtUInt:
            if (constant.uConst == 0u)
            {
                returnValue.setUConst(0u);
            }
            else
            {
                returnValue.setUConst(uConst % constant.uConst);
            }
            break;
        default:
            UNREACHABLE();
    }

    return returnValue;
}
418 
419 // static
rshift(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)420 TConstantUnion TConstantUnion::rshift(const TConstantUnion &lhs,
421                                       const TConstantUnion &rhs,
422                                       TDiagnostics *diag,
423                                       const TSourceLoc &line)
424 {
425     TConstantUnion returnValue;
426     ASSERT(lhs.type == EbtInt || lhs.type == EbtUInt);
427     ASSERT(rhs.type == EbtInt || rhs.type == EbtUInt);
428     if (!IsValidShiftOffset(rhs))
429     {
430         diag->warning(line, "Undefined shift (operand out of range)", ">>");
431         switch (lhs.type)
432         {
433             case EbtInt:
434                 returnValue.setIConst(0);
435                 break;
436             case EbtUInt:
437                 returnValue.setUConst(0u);
438                 break;
439             default:
440                 UNREACHABLE();
441         }
442         return returnValue;
443     }
444 
445     switch (lhs.type)
446     {
447         case EbtInt:
448         {
449             unsigned int shiftOffset = 0;
450             switch (rhs.type)
451             {
452                 case EbtInt:
453                     shiftOffset = static_cast<unsigned int>(rhs.iConst);
454                     break;
455                 case EbtUInt:
456                     shiftOffset = rhs.uConst;
457                     break;
458                 default:
459                     UNREACHABLE();
460             }
461             if (shiftOffset > 0)
462             {
463                 // ESSL 3.00.6 section 5.9: "If E1 is a signed integer, the right-shift will extend
464                 // the sign bit." In C++ shifting negative integers is undefined, so we implement
465                 // extending the sign bit manually.
466                 int lhsSafe = lhs.iConst;
467                 if (lhsSafe == std::numeric_limits<int>::min())
468                 {
469                     // The min integer needs special treatment because only bit it has set is the
470                     // sign bit, which we clear later to implement safe right shift of negative
471                     // numbers.
472                     lhsSafe = -0x40000000;
473                     --shiftOffset;
474                 }
475                 if (shiftOffset > 0)
476                 {
477                     bool extendSignBit = false;
478                     if (lhsSafe < 0)
479                     {
480                         extendSignBit = true;
481                         // Clear the sign bit so that bitshift right is defined in C++.
482                         lhsSafe &= 0x7fffffff;
483                         ASSERT(lhsSafe > 0);
484                     }
485                     returnValue.setIConst(lhsSafe >> shiftOffset);
486 
487                     // Manually fill in the extended sign bit if necessary.
488                     if (extendSignBit)
489                     {
490                         int extendedSignBit = static_cast<int>(0xffffffffu << (31 - shiftOffset));
491                         returnValue.setIConst(returnValue.getIConst() | extendedSignBit);
492                     }
493                 }
494                 else
495                 {
496                     returnValue.setIConst(lhsSafe);
497                 }
498             }
499             else
500             {
501                 returnValue.setIConst(lhs.iConst);
502             }
503             break;
504         }
505         case EbtUInt:
506             switch (rhs.type)
507             {
508                 case EbtInt:
509                     returnValue.setUConst(lhs.uConst >> rhs.iConst);
510                     break;
511                 case EbtUInt:
512                     returnValue.setUConst(lhs.uConst >> rhs.uConst);
513                     break;
514                 default:
515                     UNREACHABLE();
516             }
517             break;
518 
519         default:
520             UNREACHABLE();
521     }
522     return returnValue;
523 }
524 
525 // static
lshift(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)526 TConstantUnion TConstantUnion::lshift(const TConstantUnion &lhs,
527                                       const TConstantUnion &rhs,
528                                       TDiagnostics *diag,
529                                       const TSourceLoc &line)
530 {
531     TConstantUnion returnValue;
532     ASSERT(lhs.type == EbtInt || lhs.type == EbtUInt);
533     ASSERT(rhs.type == EbtInt || rhs.type == EbtUInt);
534     if (!IsValidShiftOffset(rhs))
535     {
536         diag->warning(line, "Undefined shift (operand out of range)", "<<");
537         switch (lhs.type)
538         {
539             case EbtInt:
540                 returnValue.setIConst(0);
541                 break;
542             case EbtUInt:
543                 returnValue.setUConst(0u);
544                 break;
545             default:
546                 UNREACHABLE();
547         }
548         return returnValue;
549     }
550 
551     switch (lhs.type)
552     {
553         case EbtInt:
554             switch (rhs.type)
555             {
556                 // Cast to unsigned integer before shifting, since ESSL 3.00.6 section 5.9 says that
557                 // lhs is "interpreted as a bit pattern". This also avoids the possibility of signed
558                 // integer overflow or undefined shift of a negative integer.
559                 case EbtInt:
560                     returnValue.setIConst(
561                         static_cast<int>(static_cast<uint32_t>(lhs.iConst) << rhs.iConst));
562                     break;
563                 case EbtUInt:
564                     returnValue.setIConst(
565                         static_cast<int>(static_cast<uint32_t>(lhs.iConst) << rhs.uConst));
566                     break;
567                 default:
568                     UNREACHABLE();
569             }
570             break;
571 
572         case EbtUInt:
573             switch (rhs.type)
574             {
575                 case EbtInt:
576                     returnValue.setUConst(lhs.uConst << rhs.iConst);
577                     break;
578                 case EbtUInt:
579                     returnValue.setUConst(lhs.uConst << rhs.uConst);
580                     break;
581                 default:
582                     UNREACHABLE();
583             }
584             break;
585 
586         default:
587             UNREACHABLE();
588     }
589     return returnValue;
590 }
591 
// Bitwise AND of two integer constants. |constant| must be int or uint (asserted); the result
// takes the type of this object. Any other type for this object hits UNREACHABLE().
TConstantUnion TConstantUnion::operator&(const TConstantUnion &constant) const
{
    TConstantUnion masked;
    ASSERT(constant.type == EbtInt || constant.type == EbtUInt);

    if (type == EbtInt)
    {
        masked.setIConst(iConst & constant.iConst);
    }
    else if (type == EbtUInt)
    {
        masked.setUConst(uConst & constant.uConst);
    }
    else
    {
        UNREACHABLE();
    }

    return masked;
}
610 
// Bitwise OR of two int or two uint constants. Operand types must match (asserted); any other
// type hits UNREACHABLE().
TConstantUnion TConstantUnion::operator|(const TConstantUnion &constant) const
{
    TConstantUnion combined;
    ASSERT(type == constant.type);

    if (type == EbtInt)
    {
        combined.setIConst(iConst | constant.iConst);
    }
    else if (type == EbtUInt)
    {
        combined.setUConst(uConst | constant.uConst);
    }
    else
    {
        UNREACHABLE();
    }

    return combined;
}
629 
// Bitwise XOR of two int or two uint constants. Operand types must match (asserted); any other
// type hits UNREACHABLE().
TConstantUnion TConstantUnion::operator^(const TConstantUnion &constant) const
{
    TConstantUnion toggled;
    ASSERT(type == constant.type);

    if (type == EbtInt)
    {
        toggled.setIConst(iConst ^ constant.iConst);
    }
    else if (type == EbtUInt)
    {
        toggled.setUConst(uConst ^ constant.uConst);
    }
    else
    {
        UNREACHABLE();
    }

    return toggled;
}
648 
// Logical AND of two bool constants. Operand types must match (asserted); any non-bool type hits
// UNREACHABLE().
TConstantUnion TConstantUnion::operator&&(const TConstantUnion &constant) const
{
    TConstantUnion conjunction;
    ASSERT(type == constant.type);

    if (type == EbtBool)
    {
        conjunction.setBConst(bConst && constant.bConst);
    }
    else
    {
        UNREACHABLE();
    }

    return conjunction;
}
664 
// Logical OR of two bool constants. Operand types must match (asserted); any non-bool type hits
// UNREACHABLE().
TConstantUnion TConstantUnion::operator||(const TConstantUnion &constant) const
{
    TConstantUnion disjunction;
    ASSERT(type == constant.type);

    if (type == EbtBool)
    {
        disjunction.setBConst(bConst || constant.bConst);
    }
    else
    {
        UNREACHABLE();
    }

    return disjunction;
}
680 
681 }  // namespace sh
682