/*
 * Copyright (c) 2017, 2019, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */


package org.graalvm.compiler.replacements.test;

import static org.junit.Assert.assertNotNull;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.test.GraalCompilerTest;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.ReturnNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.spi.LoweringTool;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.common.GuardLoweringPhase;
import org.graalvm.compiler.phases.common.LoweringPhase;
import org.graalvm.compiler.phases.tiers.HighTierContext;
import org.graalvm.compiler.phases.tiers.MidTierContext;
import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode;
import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

/**
 * Parameterized test checking that exact integer arithmetic ({@link Math#addExact},
 * {@link Math#subtractExact}, {@link Math#multiplyExact}) is folded by the canonicalizer exactly
 * when the operand stamps rule out an overflow.
 * <p>
 * Each parameter set gives the value range of both operands ({@code [lowerBoundA, upperBoundA]} and
 * {@code [lowerBoundB, upperBoundB]}), the operand width in bits (32 or 64) and the
 * {@link Operation} under test. After canonicalization the returned node is inspected: if the
 * exact-arithmetic node is still there, the compiler assumed an overflow is possible, and the
 * {@link Operation} compares that decision (and the result stamp) against a direct evaluation of
 * the operation on the operand ranges.
 */
@RunWith(Parameterized.class)
public class IntegerExactFoldTest extends GraalCompilerTest {
    private final long lowerBoundA;
    private final long upperBoundA;
    private final long lowerBoundB;
    private final long upperBoundB;
    private final int bits;
    private final Operation operation;

    /**
     * Creates a test instance for one parameter set produced by {@link #data()}.
     *
     * @param lowerBoundA smallest value of the first operand
     * @param upperBoundA largest value of the first operand
     * @param lowerBoundB smallest value of the second operand
     * @param upperBoundB largest value of the second operand
     * @param bits operand width, either 32 ({@code int}) or 64 ({@code long})
     * @param operation the exact arithmetic operation under test
     */
    public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) {
        this.lowerBoundA = lowerBoundA;
        this.upperBoundA = upperBoundA;
        this.lowerBoundB = lowerBoundB;
        this.upperBoundB = upperBoundB;
        this.bits = bits;
        this.operation = operation;

        // Sanity-check the parameter set: only int/long widths are supported, both ranges must be
        // non-empty, and 32-bit ranges must be representable as int.
        assert bits == 32 || bits == 64;
        assert lowerBoundA <= upperBoundA;
        assert lowerBoundB <= upperBoundB;
        assert bits == 64 || isInteger(lowerBoundA);
        assert bits == 64 || isInteger(upperBoundA);
        assert bits == 64 || isInteger(lowerBoundB);
        assert bits == 64 || isInteger(upperBoundB);
    }

    /**
     * Narrows the operand stamps of the high-level {@link IntegerExactArithmeticNode} with
     * {@link PiNode}s, canonicalizes, and verifies the folding decision.
     */
    @Test
    public void testFolding() {
        StructuredGraph graph = prepareGraph();
        IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA);
        IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB);

        // Insert PiNodes carrying the narrowed stamps, but only in front of the exact-arithmetic
        // node, so no other usage of the parameters is affected.
        List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot();
        params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode);
        params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode);

        Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
        assertNotNull("original node must be in the graph", originalNode);

        createCanonicalizerPhase().apply(graph, getDefaultHighTierContext());

        // If the exact node survived canonicalization, the compiler considers an overflow possible.
        ValueNode node = findNode(graph);
        boolean overflowExpected = node instanceof IntegerExactArithmeticNode;

        IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
        operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
    }

    /**
     * Lowers the exact node to its {@link IntegerExactArithmeticSplitNode} form first, then narrows
     * the stamps of the split node's inputs, canonicalizes again, and verifies the folding decision.
     */
    @Test
    public void testFoldingAfterLowering() {
        StructuredGraph graph = prepareGraph();

        Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
        assertNotNull("original node must be in the graph", originalNode);
        CanonicalizerPhase canonicalizer = createCanonicalizerPhase();
        HighTierContext highTierContext = getDefaultHighTierContext();
        new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, highTierContext);
        MidTierContext midTierContext = getDefaultMidTierContext();
        new GuardLoweringPhase().apply(graph, midTierContext);
        createCanonicalizerPhase().apply(graph, midTierContext);

        IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first();
        assertNotNull("the lowered node must be in the graph", loweredNode);

        // Narrow the input stamps directly on the split node's operands before re-canonicalizing.
        loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA));
        loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB));
        createCanonicalizerPhase().apply(graph, midTierContext);

        // If the split node survived canonicalization, the compiler considers an overflow possible.
        ValueNode node = findNode(graph);
        boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode;

        IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
        operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
    }

    /** Returns whether {@code value} lies within the range of a Java {@code int}. */
    private static boolean isInteger(long value) {
        return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
    }

    /** Returns the value returned by the first {@link ReturnNode} found in {@code graph}. */
    private static ValueNode findNode(StructuredGraph graph) {
        ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result();
        assertNotNull("some node must be the returned value", resultNode);
        return resultNode;
    }

    /**
     * Parses the {@code snippetInt32} or {@code snippetInt64} method (selected by {@link #bits}) of
     * the operation's class and runs the canonicalizer on it once.
     */
    protected StructuredGraph prepareGraph() {
        String snippet = "snippetInt" + bits;
        StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO);
        HighTierContext context = getDefaultHighTierContext();
        createCanonicalizerPhase().apply(graph, context);
        return graph;
    }

    /** Appends one parameter set in the order expected by the test constructor. */
    private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) {
        tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation});
    }

    /**
     * Builds the parameter sets: for every operation and both widths, a mix of zero-related ranges,
     * small ranges and ranges at the {@code int} limits.
     */
    @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}")
    public static Collection<Object[]> data() {
        ArrayList<Object[]> tests = new ArrayList<>();

        Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()};
        for (Operation operation : operations) {
            for (int bits : new int[]{32, 64}) {
                // zero related
                addTest(tests, 0, 0, 1, 1, bits, operation);
                addTest(tests, 1, 1, 0, 0, bits, operation);
                addTest(tests, -1, 1, 0, 1, bits, operation);

                // bounds
                addTest(tests, -2, 2, 3, 3, bits, operation);
                addTest(tests, -1, 1, 1, 1, bits, operation);
                addTest(tests, -1, 1, -1, 1, bits, operation);

                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation);
                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation);
                addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation);
                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation);
                addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation);
                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation);
            }

            // bit-specific test cases
            // NOTE(review): these two sets are identical to cases the loop above already adds for
            // bits=64. They were presumably meant to cover 64-bit-only ranges (values outside the
            // int range); confirm and replace them with such ranges.
            addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation);
            addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation);
        }

        return tests;
    }

    /**
     * An exact arithmetic operation under test. Each subclass also provides the
     * {@code snippetInt32}/{@code snippetInt64} methods that {@link #prepareGraph()} parses.
     */
    private abstract static class Operation {
        /**
         * Checks the compiler's folding decision against a direct evaluation of the operation.
         *
         * @param overflowExpected {@code true} if the exact-arithmetic node survived
         *            canonicalization, i.e. the compiler considers an overflow possible
         * @param resultStamp stamp of the node returned after canonicalization; it must contain
         *            every non-overflowing result that is checked
         */
        abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp);

        /**
         * Returns a stable name for the {@code {5}} placeholder of the {@link Parameters} name
         * pattern. The default {@code Object.toString()} embeds an identity hash code, which would
         * make the parameterized test names differ from run to run.
         */
        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    private static final class AddOperation extends Operation {
        @Override
        public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
            try {
                // Addition is monotonic in both operands: the smallest result is lowerA + lowerB
                // and the largest is upperA + upperB.
                long res = addExact(lowerBoundA, lowerBoundB, bits);
                Assert.assertTrue("result stamp " + resultStamp + " must contain " + res, resultStamp.contains(res));
                res = addExact(upperBoundA, upperBoundB, bits);
                Assert.assertTrue("result stamp " + resultStamp + " must contain " + res, resultStamp.contains(res));
                Assert.assertFalse(overflowExpected);
            } catch (ArithmeticException e) {
                Assert.assertTrue(overflowExpected);
            }
        }

        /** Adds with overflow detection at the given width (32 bits uses {@code int} arithmetic). */
        private static long addExact(long x, long y, int bits) {
            if (bits == 32) {
                return Math.addExact((int) x, (int) y);
            } else {
                return Math.addExact(x, y);
            }
        }

        @SuppressWarnings("unused")
        public static int snippetInt32(int a, int b) {
            return Math.addExact(a, b);
        }

        @SuppressWarnings("unused")
        public static long snippetInt64(long a, long b) {
            return Math.addExact(a, b);
        }
    }

    private static final class SubOperation extends Operation {
        @Override
        public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
            try {
                // Subtraction grows with a and shrinks with b: the smallest result is
                // lowerA - upperB and the largest is upperA - lowerB.
                long res = subExact(lowerBoundA, upperBoundB, bits);
                Assert.assertTrue("result stamp " + resultStamp + " must contain " + res, resultStamp.contains(res));
                res = subExact(upperBoundA, lowerBoundB, bits);
                Assert.assertTrue("result stamp " + resultStamp + " must contain " + res, resultStamp.contains(res));
                Assert.assertFalse(overflowExpected);
            } catch (ArithmeticException e) {
                Assert.assertTrue(overflowExpected);
            }
        }

        /** Subtracts with overflow detection at the given width (32 bits uses {@code int} arithmetic). */
        private static long subExact(long x, long y, int bits) {
            if (bits == 32) {
                return Math.subtractExact((int) x, (int) y);
            } else {
                return Math.subtractExact(x, y);
            }
        }

        @SuppressWarnings("unused")
        public static int snippetInt32(int a, int b) {
            return Math.subtractExact(a, b);
        }

        @SuppressWarnings("unused")
        public static long snippetInt64(long a, long b) {
            return Math.subtractExact(a, b);
        }
    }

    private static final class MulOperation extends Operation {
        @Override
        public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
            // Multiplication is not monotonic across sign changes, so check every pair of values
            // in the two ranges. The ranges produced by data() span at most 16 values each, which
            // keeps this exhaustive check cheap.
            boolean overflowOccurred = false;

            for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) {
                for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) {
                    try {
                        long res = mulExact(l1, l2, bits);
                        Assert.assertTrue("result stamp " + resultStamp + " must contain " + res, resultStamp.contains(res));
                    } catch (ArithmeticException e) {
                        overflowOccurred = true;
                    }
                    if (l2 == Long.MAX_VALUE) {
                        // l2++ would wrap around and the loop would never terminate
                        break;
                    }
                }
                if (l1 == Long.MAX_VALUE) {
                    // l1++ would wrap around and the loop would never terminate
                    break;
                }
            }

            Assert.assertEquals(overflowExpected, overflowOccurred);
        }

        /** Multiplies with overflow detection at the given width (32 bits uses {@code int} arithmetic). */
        private static long mulExact(long x, long y, int bits) {
            if (bits == 32) {
                return Math.multiplyExact((int) x, (int) y);
            } else {
                return Math.multiplyExact(x, y);
            }
        }

        @SuppressWarnings("unused")
        public static int snippetInt32(int a, int b) {
            return Math.multiplyExact(a, b);
        }

        @SuppressWarnings("unused")
        public static long snippetInt64(long a, long b) {
            return Math.multiplyExact(a, b);
        }
    }
}