1 /*
2  * Copyright (c) 2017, 2019, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 
25 package org.graalvm.compiler.replacements.test;
26 
27 import static org.junit.Assert.assertNotNull;
28 
29 import java.util.ArrayList;
30 import java.util.Collection;
31 import java.util.List;
32 
33 import org.graalvm.compiler.core.common.type.IntegerStamp;
34 import org.graalvm.compiler.core.common.type.StampFactory;
35 import org.graalvm.compiler.core.test.GraalCompilerTest;
36 import org.graalvm.compiler.graph.Node;
37 import org.graalvm.compiler.nodes.NodeView;
38 import org.graalvm.compiler.nodes.ParameterNode;
39 import org.graalvm.compiler.nodes.PiNode;
40 import org.graalvm.compiler.nodes.ReturnNode;
41 import org.graalvm.compiler.nodes.StructuredGraph;
42 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
43 import org.graalvm.compiler.nodes.ValueNode;
44 import org.graalvm.compiler.nodes.spi.LoweringTool;
45 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
46 import org.graalvm.compiler.phases.common.GuardLoweringPhase;
47 import org.graalvm.compiler.phases.common.LoweringPhase;
48 import org.graalvm.compiler.phases.tiers.HighTierContext;
49 import org.graalvm.compiler.phases.tiers.MidTierContext;
50 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode;
51 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode;
52 import org.junit.Assert;
53 import org.junit.Test;
54 import org.junit.runner.RunWith;
55 import org.junit.runners.Parameterized;
56 import org.junit.runners.Parameterized.Parameters;
57 
58 @RunWith(Parameterized.class)
59 public class IntegerExactFoldTest extends GraalCompilerTest {
60     private final long lowerBoundA;
61     private final long upperBoundA;
62     private final long lowerBoundB;
63     private final long upperBoundB;
64     private final int bits;
65     private final Operation operation;
66 
IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation)67     public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) {
68         this.lowerBoundA = lowerBoundA;
69         this.upperBoundA = upperBoundA;
70         this.lowerBoundB = lowerBoundB;
71         this.upperBoundB = upperBoundB;
72         this.bits = bits;
73         this.operation = operation;
74 
75         assert bits == 32 || bits == 64;
76         assert lowerBoundA <= upperBoundA;
77         assert lowerBoundB <= upperBoundB;
78         assert bits == 64 || isInteger(lowerBoundA);
79         assert bits == 64 || isInteger(upperBoundA);
80         assert bits == 64 || isInteger(lowerBoundB);
81         assert bits == 64 || isInteger(upperBoundB);
82     }
83 
84     @Test
testFolding()85     public void testFolding() {
86         StructuredGraph graph = prepareGraph();
87         IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA);
88         IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB);
89 
90         List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot();
91         params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode);
92         params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode);
93 
94         Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
95         assertNotNull("original node must be in the graph", originalNode);
96 
97         createCanonicalizerPhase().apply(graph, getDefaultHighTierContext());
98 
99         ValueNode node = findNode(graph);
100         boolean overflowExpected = node instanceof IntegerExactArithmeticNode;
101 
102         IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
103         operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
104     }
105 
106     @Test
testFoldingAfterLowering()107     public void testFoldingAfterLowering() {
108         StructuredGraph graph = prepareGraph();
109 
110         Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
111         assertNotNull("original node must be in the graph", originalNode);
112         CanonicalizerPhase canonicalizer = createCanonicalizerPhase();
113         HighTierContext highTierContext = getDefaultHighTierContext();
114         new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, highTierContext);
115         MidTierContext midTierContext = getDefaultMidTierContext();
116         new GuardLoweringPhase().apply(graph, midTierContext);
117         createCanonicalizerPhase().apply(graph, midTierContext);
118 
119         IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first();
120         assertNotNull("the lowered node must be in the graph", loweredNode);
121 
122         loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA));
123         loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB));
124         createCanonicalizerPhase().apply(graph, midTierContext);
125 
126         ValueNode node = findNode(graph);
127         boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode;
128 
129         IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
130         operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
131     }
132 
isInteger(long value)133     private static boolean isInteger(long value) {
134         return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
135     }
136 
findNode(StructuredGraph graph)137     private static ValueNode findNode(StructuredGraph graph) {
138         ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result();
139         assertNotNull("some node must be the returned value", resultNode);
140         return resultNode;
141     }
142 
prepareGraph()143     protected StructuredGraph prepareGraph() {
144         String snippet = "snippetInt" + bits;
145         StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO);
146         HighTierContext context = getDefaultHighTierContext();
147         createCanonicalizerPhase().apply(graph, context);
148         return graph;
149     }
150 
addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation)151     private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) {
152         tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation});
153     }
154 
155     @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}")
data()156     public static Collection<Object[]> data() {
157         ArrayList<Object[]> tests = new ArrayList<>();
158 
159         Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()};
160         for (Operation operation : operations) {
161             for (int bits : new int[]{32, 64}) {
162                 // zero related
163                 addTest(tests, 0, 0, 1, 1, bits, operation);
164                 addTest(tests, 1, 1, 0, 0, bits, operation);
165                 addTest(tests, -1, 1, 0, 1, bits, operation);
166 
167                 // bounds
168                 addTest(tests, -2, 2, 3, 3, bits, operation);
169                 addTest(tests, -1, 1, 1, 1, bits, operation);
170                 addTest(tests, -1, 1, -1, 1, bits, operation);
171 
172                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation);
173                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation);
174                 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation);
175                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation);
176                 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation);
177                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation);
178             }
179 
180             // bit-specific test cases
181             addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation);
182             addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation);
183         }
184 
185         return tests;
186     }
187 
188     private abstract static class Operation {
verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp)189         abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp);
190     }
191 
192     private static final class AddOperation extends Operation {
193         @Override
verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp)194         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
195             try {
196                 long res = addExact(lowerBoundA, lowerBoundB, bits);
197                 resultStamp.contains(res);
198                 res = addExact(upperBoundA, upperBoundB, bits);
199                 resultStamp.contains(res);
200                 Assert.assertFalse(overflowExpected);
201             } catch (ArithmeticException e) {
202                 Assert.assertTrue(overflowExpected);
203             }
204         }
205 
addExact(long x, long y, int bits)206         private static long addExact(long x, long y, int bits) {
207             if (bits == 32) {
208                 return Math.addExact((int) x, (int) y);
209             } else {
210                 return Math.addExact(x, y);
211             }
212         }
213 
214         @SuppressWarnings("unused")
snippetInt32(int a, int b)215         public static int snippetInt32(int a, int b) {
216             return Math.addExact(a, b);
217         }
218 
219         @SuppressWarnings("unused")
snippetInt64(long a, long b)220         public static long snippetInt64(long a, long b) {
221             return Math.addExact(a, b);
222         }
223     }
224 
225     private static final class SubOperation extends Operation {
226         @Override
verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp)227         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
228             try {
229                 long res = subExact(lowerBoundA, upperBoundB, bits);
230                 Assert.assertTrue(resultStamp.contains(res));
231                 res = subExact(upperBoundA, lowerBoundB, bits);
232                 Assert.assertTrue(resultStamp.contains(res));
233                 Assert.assertFalse(overflowExpected);
234             } catch (ArithmeticException e) {
235                 Assert.assertTrue(overflowExpected);
236             }
237         }
238 
subExact(long x, long y, int bits)239         private static long subExact(long x, long y, int bits) {
240             if (bits == 32) {
241                 return Math.subtractExact((int) x, (int) y);
242             } else {
243                 return Math.subtractExact(x, y);
244             }
245         }
246 
247         @SuppressWarnings("unused")
snippetInt32(int a, int b)248         public static int snippetInt32(int a, int b) {
249             return Math.subtractExact(a, b);
250         }
251 
252         @SuppressWarnings("unused")
snippetInt64(long a, long b)253         public static long snippetInt64(long a, long b) {
254             return Math.subtractExact(a, b);
255         }
256     }
257 
258     private static final class MulOperation extends Operation {
259         @Override
verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp)260         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
261             // now check for all values in the stamp whether their products overflow overflow
262             boolean overflowOccurred = false;
263 
264             for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) {
265                 for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) {
266                     try {
267                         long res = mulExact(l1, l2, bits);
268                         Assert.assertTrue(resultStamp.contains(res));
269                     } catch (ArithmeticException e) {
270                         overflowOccurred = true;
271                     }
272                     if (l2 == Long.MAX_VALUE) {
273                         // do not want to overflow the check loop
274                         break;
275                     }
276                 }
277                 if (l1 == Long.MAX_VALUE) {
278                     // do not want to overflow the check loop
279                     break;
280                 }
281             }
282 
283             Assert.assertEquals(overflowExpected, overflowOccurred);
284         }
285 
mulExact(long x, long y, int bits)286         private static long mulExact(long x, long y, int bits) {
287             if (bits == 32) {
288                 return Math.multiplyExact((int) x, (int) y);
289             } else {
290                 return Math.multiplyExact(x, y);
291             }
292         }
293 
294         @SuppressWarnings("unused")
snippetInt32(int a, int b)295         public static int snippetInt32(int a, int b) {
296             return Math.multiplyExact(a, b);
297         }
298 
299         @SuppressWarnings("unused")
snippetInt64(long a, long b)300         public static long snippetInt64(long a, long b) {
301             return Math.multiplyExact(a, b);
302         }
303     }
304 }
305