1 /*
2  * Copyright (c) 2017, 2019, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 
25 package org.graalvm.compiler.nodes.calc;
26 
27 import static jdk.vm.ci.code.CodeUtil.mask;
28 
29 import org.graalvm.compiler.core.common.calc.CanonicalCondition;
30 import org.graalvm.compiler.core.common.type.IntegerStamp;
31 import org.graalvm.compiler.core.common.type.Stamp;
32 import org.graalvm.compiler.graph.NodeClass;
33 import org.graalvm.compiler.nodeinfo.NodeInfo;
34 import org.graalvm.compiler.nodes.ConstantNode;
35 import org.graalvm.compiler.nodes.LogicConstantNode;
36 import org.graalvm.compiler.nodes.LogicNegationNode;
37 import org.graalvm.compiler.nodes.LogicNode;
38 import org.graalvm.compiler.nodes.NodeView;
39 import org.graalvm.compiler.nodes.ValueNode;
40 import org.graalvm.compiler.nodes.util.GraphUtil;
41 import org.graalvm.compiler.options.OptionValues;
42 
43 import jdk.vm.ci.code.CodeUtil;
44 import jdk.vm.ci.meta.ConstantReflectionProvider;
45 import jdk.vm.ci.meta.JavaConstant;
46 import jdk.vm.ci.meta.MetaAccessProvider;
47 import jdk.vm.ci.meta.TriState;
48 
49 /**
50  * Common super-class for "a < b" comparisons both {@linkplain IntegerLowerThanNode signed} and
51  * {@linkplain IntegerBelowNode unsigned}.
52  */
53 @NodeInfo()
54 public abstract class IntegerLowerThanNode extends CompareNode {
55     public static final NodeClass<IntegerLowerThanNode> TYPE = NodeClass.create(IntegerLowerThanNode.class);
56     private final LowerOp op;
57 
IntegerLowerThanNode(NodeClass<? extends CompareNode> c, ValueNode x, ValueNode y, LowerOp op)58     protected IntegerLowerThanNode(NodeClass<? extends CompareNode> c, ValueNode x, ValueNode y, LowerOp op) {
59         super(c, op.getCondition(), false, x, y);
60         this.op = op;
61     }
62 
getOp()63     protected LowerOp getOp() {
64         return op;
65     }
66 
67     @Override
getSucceedingStampForX(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric)68     public Stamp getSucceedingStampForX(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric) {
69         return getSucceedingStampForX(negated, !negated, xStampGeneric, yStampGeneric, getX(), getY());
70     }
71 
72     @Override
getSucceedingStampForY(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric)73     public Stamp getSucceedingStampForY(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric) {
74         return getSucceedingStampForX(!negated, !negated, yStampGeneric, xStampGeneric, getY(), getX());
75     }
76 
getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric, ValueNode forX, ValueNode forY)77     private Stamp getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric, ValueNode forX, ValueNode forY) {
78         Stamp s = getSucceedingStampForX(mirror, strict, xStampGeneric, yStampGeneric);
79         if (s != null && s.isUnrestricted()) {
80             s = null;
81         }
82         if (forY instanceof AddNode && xStampGeneric instanceof IntegerStamp) {
83             IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
84             AddNode addNode = (AddNode) forY;
85             IntegerStamp aStamp = null;
86             if (addNode.getX() == forX && addNode.getY().stamp(NodeView.DEFAULT) instanceof IntegerStamp) {
87                 // x < x + a
88                 aStamp = (IntegerStamp) addNode.getY().stamp(NodeView.DEFAULT);
89             } else if (addNode.getY() == forX && addNode.getX().stamp(NodeView.DEFAULT) instanceof IntegerStamp) {
90                 // x < a + x
91                 aStamp = (IntegerStamp) addNode.getX().stamp(NodeView.DEFAULT);
92             }
93             if (aStamp != null) {
94                 IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp, xStamp);
95                 result = (IntegerStamp) xStamp.tryImproveWith(result);
96                 if (result != null) {
97                     if (s != null) {
98                         s = s.improveWith(result);
99                     } else {
100                         s = result;
101                     }
102                 }
103             }
104         }
105         return s;
106     }
107 
getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric)108     private Stamp getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric) {
109         if (xStampGeneric instanceof IntegerStamp) {
110             IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
111             if (yStampGeneric instanceof IntegerStamp) {
112                 IntegerStamp yStamp = (IntegerStamp) yStampGeneric;
113                 assert yStamp.getBits() == xStamp.getBits();
114                 Stamp s = getOp().getSucceedingStampForX(xStamp, yStamp, mirror, strict);
115                 if (s != null) {
116                     return s;
117                 }
118             }
119         }
120         return null;
121     }
122 
123     @Override
tryFold(Stamp xStampGeneric, Stamp yStampGeneric)124     public TriState tryFold(Stamp xStampGeneric, Stamp yStampGeneric) {
125         return getOp().tryFold(xStampGeneric, yStampGeneric);
126     }
127 
128     public abstract static class LowerOp extends CompareOp {
129         @Override
canonical(ConstantReflectionProvider constantReflection, MetaAccessProvider metaAccess, OptionValues options, Integer smallestCompareWidth, CanonicalCondition condition, boolean unorderedIsTrue, ValueNode forX, ValueNode forY, NodeView view)130         public LogicNode canonical(ConstantReflectionProvider constantReflection, MetaAccessProvider metaAccess, OptionValues options, Integer smallestCompareWidth, CanonicalCondition condition,
131                         boolean unorderedIsTrue, ValueNode forX, ValueNode forY, NodeView view) {
132             LogicNode result = super.canonical(constantReflection, metaAccess, options, smallestCompareWidth, condition, unorderedIsTrue, forX, forY, view);
133             if (result != null) {
134                 return result;
135             }
136             LogicNode synonym = findSynonym(forX, forY, view);
137             if (synonym != null) {
138                 return synonym;
139             }
140             return null;
141         }
142 
upperBound(IntegerStamp stamp)143         protected abstract long upperBound(IntegerStamp stamp);
144 
lowerBound(IntegerStamp stamp)145         protected abstract long lowerBound(IntegerStamp stamp);
146 
compare(long a, long b)147         protected abstract int compare(long a, long b);
148 
min(long a, long b)149         protected abstract long min(long a, long b);
150 
max(long a, long b)151         protected abstract long max(long a, long b);
152 
min(long a, long b, int bits)153         protected long min(long a, long b, int bits) {
154             return min(cast(a, bits), cast(b, bits));
155         }
156 
max(long a, long b, int bits)157         protected long max(long a, long b, int bits) {
158             return max(cast(a, bits), cast(b, bits));
159         }
160 
cast(long a, int bits)161         protected abstract long cast(long a, int bits);
162 
minValue(int bits)163         protected abstract long minValue(int bits);
164 
maxValue(int bits)165         protected abstract long maxValue(int bits);
166 
forInteger(int bits, long min, long max)167         protected abstract IntegerStamp forInteger(int bits, long min, long max);
168 
getCondition()169         protected abstract CanonicalCondition getCondition();
170 
createNode(ValueNode x, ValueNode y)171         protected abstract IntegerLowerThanNode createNode(ValueNode x, ValueNode y);
172 
create(ValueNode x, ValueNode y, NodeView view)173         public LogicNode create(ValueNode x, ValueNode y, NodeView view) {
174             LogicNode result = CompareNode.tryConstantFoldPrimitive(getCondition(), x, y, false, view);
175             if (result != null) {
176                 return result;
177             } else {
178                 result = findSynonym(x, y, view);
179                 if (result != null) {
180                     return result;
181                 }
182                 return createNode(x, y);
183             }
184         }
185 
findSynonym(ValueNode forX, ValueNode forY, NodeView view)186         protected LogicNode findSynonym(ValueNode forX, ValueNode forY, NodeView view) {
187             if (GraphUtil.unproxify(forX) == GraphUtil.unproxify(forY)) {
188                 return LogicConstantNode.contradiction();
189             }
190             Stamp xStampGeneric = forX.stamp(view);
191             TriState fold = tryFold(xStampGeneric, forY.stamp(view));
192             if (fold.isTrue()) {
193                 return LogicConstantNode.tautology();
194             } else if (fold.isFalse()) {
195                 return LogicConstantNode.contradiction();
196             }
197             if (forY.stamp(view) instanceof IntegerStamp) {
198                 IntegerStamp yStamp = (IntegerStamp) forY.stamp(view);
199                 IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
200                 int bits = yStamp.getBits();
201                 if (forX.isJavaConstant() && !forY.isConstant()) {
202                     // bring the constant on the right
203                     long xValue = forX.asJavaConstant().asLong();
204                     if (xValue != maxValue(bits)) {
205                         // c < x <=> !(c >= x) <=> !(x <= c) <=> !(x < c + 1)
206                         return LogicNegationNode.create(create(forY, ConstantNode.forIntegerStamp(yStamp, xValue + 1), view));
207                     }
208                 }
209                 if (forY.isJavaConstant()) {
210                     long yValue = forY.asJavaConstant().asLong();
211 
212                     // x < MAX <=> x != MAX
213                     if (yValue == maxValue(bits)) {
214                         return LogicNegationNode.create(IntegerEqualsNode.create(forX, forY, view));
215                     }
216 
217                     // x < MIN + 1 <=> x <= MIN <=> x == MIN
218                     if (yValue == minValue(bits) + 1) {
219                         return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, minValue(bits)), view);
220                     }
221 
222                     // (x < c && x >= c - 1) => x == c - 1
223                     // If the constant is negative, only signed comparison is allowed.
224                     if (yValue != minValue(bits) && xStamp.lowerBound() == yValue - 1 && (yValue > 0 || getCondition() == CanonicalCondition.LT)) {
225                         return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, yValue - 1), view);
226                     }
227 
228                 } else if (forY instanceof AddNode) {
229                     AddNode addNode = (AddNode) forY;
230                     LogicNode canonical = canonicalizeXLowerXPlusA(forX, addNode, false, true, view);
231                     if (canonical != null) {
232                         return canonical;
233                     }
234                 }
235                 if (forX instanceof AddNode) {
236                     AddNode addNode = (AddNode) forX;
237                     LogicNode canonical = canonicalizeXLowerXPlusA(forY, addNode, true, false, view);
238                     if (canonical != null) {
239                         return canonical;
240                     }
241                 }
242             }
243             return null;
244         }
245 
246         /**
247          * Exploit the fact that adding the (signed) MIN_VALUE on both side flips signed and
248          * unsigned comparison.
249          *
250          * In particular:
251          * <ul>
252          * <li>{@code x + MIN_VALUE < y + MIN_VALUE <=> x |<| y}</li>
253          * <li>{@code x + MIN_VALUE |<| y + MIN_VALUE <=> x < y}</li>
254          * </ul>
255          */
canonicalizeRangeFlip(ValueNode forX, ValueNode forY, int bits, boolean signed, NodeView view)256         protected static LogicNode canonicalizeRangeFlip(ValueNode forX, ValueNode forY, int bits, boolean signed, NodeView view) {
257             long min = CodeUtil.minValue(bits);
258             long xResidue = 0;
259             ValueNode left = null;
260             JavaConstant leftCst = null;
261             if (forX instanceof AddNode) {
262                 AddNode xAdd = (AddNode) forX;
263                 if (xAdd.getY().isJavaConstant() && !xAdd.getY().asJavaConstant().isDefaultForKind()) {
264                     long xCst = xAdd.getY().asJavaConstant().asLong();
265                     xResidue = xCst - min;
266                     left = xAdd.getX();
267                 }
268             } else if (forX.isJavaConstant()) {
269                 leftCst = forX.asJavaConstant();
270             }
271             if (left == null && leftCst == null) {
272                 return null;
273             }
274             long yResidue = 0;
275             ValueNode right = null;
276             JavaConstant rightCst = null;
277             if (forY instanceof AddNode) {
278                 AddNode yAdd = (AddNode) forY;
279                 if (yAdd.getY().isJavaConstant() && !yAdd.getY().asJavaConstant().isDefaultForKind()) {
280                     long yCst = yAdd.getY().asJavaConstant().asLong();
281                     yResidue = yCst - min;
282                     right = yAdd.getX();
283                 }
284             } else if (forY.isJavaConstant()) {
285                 rightCst = forY.asJavaConstant();
286             }
287             if (right == null && rightCst == null) {
288                 return null;
289             }
290             if ((xResidue == 0 && left != null) || (yResidue == 0 && right != null)) {
291                 if (left == null) {
292                     // Fortify: Suppress Null Dereference false positive
293                     assert leftCst != null;
294 
295                     left = ConstantNode.forIntegerBits(bits, leftCst.asLong() - min);
296                 } else if (xResidue != 0) {
297                     left = AddNode.create(left, ConstantNode.forIntegerBits(bits, xResidue), view);
298                 }
299                 if (right == null) {
300                     // Fortify: Suppress Null Dereference false positive
301                     assert rightCst != null;
302 
303                     right = ConstantNode.forIntegerBits(bits, rightCst.asLong() - min);
304                 } else if (yResidue != 0) {
305                     right = AddNode.create(right, ConstantNode.forIntegerBits(bits, yResidue), view);
306                 }
307                 if (signed) {
308                     return new IntegerBelowNode(left, right);
309                 } else {
310                     return new IntegerLessThanNode(left, right);
311                 }
312             }
313             return null;
314         }
315 
canonicalizeXLowerXPlusA(ValueNode forX, AddNode addNode, boolean mirrored, boolean strict, NodeView view)316         private LogicNode canonicalizeXLowerXPlusA(ValueNode forX, AddNode addNode, boolean mirrored, boolean strict, NodeView view) {
317             // x < x + a
318             // x |<| x + a
319             IntegerStamp xStamp = (IntegerStamp) forX.stamp(view);
320             IntegerStamp succeedingXStamp;
321             boolean exact;
322             if (addNode.getX() == forX && addNode.getY().stamp(view) instanceof IntegerStamp) {
323                 IntegerStamp aStamp = (IntegerStamp) addNode.getY().stamp(view);
324                 succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
325                 exact = aStamp.lowerBound() == aStamp.upperBound();
326             } else if (addNode.getY() == forX && addNode.getX().stamp(view) instanceof IntegerStamp) {
327                 IntegerStamp aStamp = (IntegerStamp) addNode.getX().stamp(view);
328                 succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
329                 exact = aStamp.lowerBound() == aStamp.upperBound();
330             } else {
331                 return null;
332             }
333             if (succeedingXStamp.join(forX.stamp(view)).isEmpty()) {
334                 return LogicConstantNode.contradiction();
335             } else if (exact && !succeedingXStamp.isEmpty()) {
336                 int bits = succeedingXStamp.getBits();
337                 if (compare(lowerBound(succeedingXStamp), minValue(bits)) > 0) {
338                     // x must be in [L..MAX] <=> x >= L <=> !(x < L)
339                     return LogicNegationNode.create(create(forX, ConstantNode.forIntegerStamp(succeedingXStamp, lowerBound(succeedingXStamp)), view));
340                 } else if (compare(upperBound(succeedingXStamp), maxValue(bits)) < 0) {
341                     // x must be in [MIN..H] <=> x <= H <=> !(H < x)
342                     return LogicNegationNode.create(create(ConstantNode.forIntegerStamp(succeedingXStamp, upperBound(succeedingXStamp)), forX, view));
343                 }
344             }
345             return null;
346         }
347 
tryFold(Stamp xStampGeneric, Stamp yStampGeneric)348         protected TriState tryFold(Stamp xStampGeneric, Stamp yStampGeneric) {
349             if (xStampGeneric instanceof IntegerStamp && yStampGeneric instanceof IntegerStamp) {
350                 IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
351                 IntegerStamp yStamp = (IntegerStamp) yStampGeneric;
352                 if (compare(upperBound(xStamp), lowerBound(yStamp)) < 0) {
353                     return TriState.TRUE;
354                 }
355                 if (compare(lowerBound(xStamp), upperBound(yStamp)) >= 0) {
356                     return TriState.FALSE;
357                 }
358             }
359             return TriState.UNKNOWN;
360         }
361 
getSucceedingStampForX(IntegerStamp xStamp, IntegerStamp yStamp, boolean mirror, boolean strict)362         protected IntegerStamp getSucceedingStampForX(IntegerStamp xStamp, IntegerStamp yStamp, boolean mirror, boolean strict) {
363             int bits = xStamp.getBits();
364             assert yStamp.getBits() == bits;
365             if (mirror) {
366                 long low = lowerBound(yStamp);
367                 if (strict) {
368                     if (low == maxValue(bits)) {
369                         return null;
370                     }
371                     low += 1;
372                 }
373                 if (compare(low, lowerBound(xStamp)) > 0 || upperBound(xStamp) != (xStamp.upperBound() & mask(xStamp.getBits()))) {
374                     return forInteger(bits, low, upperBound(xStamp));
375                 }
376             } else {
377                 // x < y, i.e., x < y <= Y_UPPER_BOUND so x <= Y_UPPER_BOUND - 1
378                 long low = upperBound(yStamp);
379                 if (strict) {
380                     if (low == minValue(bits)) {
381                         return null;
382                     }
383                     low -= 1;
384                 }
385                 if (compare(low, upperBound(xStamp)) < 0 || lowerBound(xStamp) != (xStamp.lowerBound() & mask(xStamp.getBits()))) {
386                     return forInteger(bits, lowerBound(xStamp), low);
387                 }
388             }
389             return null;
390         }
391 
getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp aStamp, IntegerStamp xStamp)392         protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp aStamp, IntegerStamp xStamp) {
393             int bits = aStamp.getBits();
394             long min = minValue(bits);
395             long max = maxValue(bits);
396 
397             /*
398              * if x < x + a <=> x + a didn't overflow:
399              *
400              * x is outside ]MAX - a, MAX], i.e., inside [MIN, MAX - a]
401              *
402              * if a is negative those bounds wrap around correctly.
403              *
404              * If a is exactly zero this gives an unbounded stamp (any integer) in the positive case
405              * and an empty stamp in the negative case: if x |<| x is true, then either x has no
406              * value or any value...
407              *
408              * This does not use upper/lowerBound from LowerOp because it's about the (signed)
409              * addition not the comparison.
410              */
411             if (mirrored) {
412                 if (aStamp.contains(0)) {
413                     // a may be zero
414                     return aStamp.unrestricted();
415                 }
416                 return forInteger(bits, min(max - aStamp.lowerBound() + 1, max - aStamp.upperBound() + 1, bits), min(max, upperBound(xStamp)));
417             } else {
418                 long aLower = aStamp.lowerBound();
419                 long aUpper = aStamp.upperBound();
420                 if (strict) {
421                     if (aLower == 0) {
422                         aLower = 1;
423                     }
424                     if (aUpper == 0) {
425                         aUpper = -1;
426                     }
427                     if (aLower > aUpper) {
428                         // impossible
429                         return aStamp.empty();
430                     }
431                 }
432                 if (aLower < 0 && aUpper > 0) {
433                     // a may be zero
434                     return aStamp.unrestricted();
435                 }
436                 return forInteger(bits, min, max(max - aLower, max - aUpper, bits));
437             }
438         }
439     }
440 }
441