1 /*
2  * Copyright (c) 2011, 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 
25 package org.graalvm.compiler.nodes.calc;
26 
27 import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_2;
28 
29 import org.graalvm.compiler.core.common.type.ArithmeticOpTable;
30 import org.graalvm.compiler.core.common.type.IntegerStamp;
31 import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp;
32 import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Mul;
33 import org.graalvm.compiler.core.common.type.Stamp;
34 import org.graalvm.compiler.graph.NodeClass;
35 import org.graalvm.compiler.graph.spi.Canonicalizable.BinaryCommutative;
36 import org.graalvm.compiler.graph.spi.CanonicalizerTool;
37 import org.graalvm.compiler.lir.gen.ArithmeticLIRGeneratorTool;
38 import org.graalvm.compiler.nodeinfo.NodeInfo;
39 import org.graalvm.compiler.nodes.ConstantNode;
40 import org.graalvm.compiler.nodes.NodeView;
41 import org.graalvm.compiler.nodes.ValueNode;
42 import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
43 
44 import jdk.vm.ci.code.CodeUtil;
45 import jdk.vm.ci.meta.Constant;
46 import jdk.vm.ci.meta.PrimitiveConstant;
47 import jdk.vm.ci.meta.Value;
48 
49 @NodeInfo(shortName = "*", cycles = CYCLES_2)
50 public class MulNode extends BinaryArithmeticNode<Mul> implements NarrowableArithmeticNode, BinaryCommutative<ValueNode> {
51 
52     public static final NodeClass<MulNode> TYPE = NodeClass.create(MulNode.class);
53 
MulNode(ValueNode x, ValueNode y)54     public MulNode(ValueNode x, ValueNode y) {
55         this(TYPE, x, y);
56     }
57 
MulNode(NodeClass<? extends MulNode> c, ValueNode x, ValueNode y)58     protected MulNode(NodeClass<? extends MulNode> c, ValueNode x, ValueNode y) {
59         super(c, ArithmeticOpTable::getMul, x, y);
60     }
61 
create(ValueNode x, ValueNode y, NodeView view)62     public static ValueNode create(ValueNode x, ValueNode y, NodeView view) {
63         BinaryOp<Mul> op = ArithmeticOpTable.forStamp(x.stamp(view)).getMul();
64         Stamp stamp = op.foldStamp(x.stamp(view), y.stamp(view));
65         ConstantNode tryConstantFold = tryConstantFold(op, x, y, stamp, view);
66         if (tryConstantFold != null) {
67             return tryConstantFold;
68         }
69         return canonical(null, op, stamp, x, y, view);
70     }
71 
72     @Override
canonical(CanonicalizerTool tool, ValueNode forX, ValueNode forY)73     public ValueNode canonical(CanonicalizerTool tool, ValueNode forX, ValueNode forY) {
74         ValueNode ret = super.canonical(tool, forX, forY);
75         if (ret != this) {
76             return ret;
77         }
78 
79         if (forX.isConstant() && !forY.isConstant()) {
80             // we try to swap and canonicalize
81             ValueNode improvement = canonical(tool, forY, forX);
82             if (improvement != this) {
83                 return improvement;
84             }
85             // if this fails we only swap
86             return new MulNode(forY, forX);
87         }
88         BinaryOp<Mul> op = getOp(forX, forY);
89         NodeView view = NodeView.from(tool);
90         return canonical(this, op, stamp(view), forX, forY, view);
91     }
92 
canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY, NodeView view)93     private static ValueNode canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY, NodeView view) {
94         if (forY.isConstant()) {
95             Constant c = forY.asConstant();
96             if (op.isNeutral(c)) {
97                 return forX;
98             }
99 
100             if (c instanceof PrimitiveConstant && ((PrimitiveConstant) c).getJavaKind().isNumericInteger()) {
101                 long i = ((PrimitiveConstant) c).asLong();
102                 ValueNode result = canonical(stamp, forX, i, view);
103                 if (result != null) {
104                     return result;
105                 }
106             }
107 
108             if (op.isAssociative()) {
109                 // canonicalize expressions like "(a * 1) * 2"
110                 return reassociate(self != null ? self : (MulNode) new MulNode(forX, forY).maybeCommuteInputs(), ValueNode.isConstantPredicate(), forX, forY, view);
111             }
112         }
113         return self != null ? self : new MulNode(forX, forY).maybeCommuteInputs();
114     }
115 
canonical(Stamp stamp, ValueNode forX, long i, NodeView view)116     public static ValueNode canonical(Stamp stamp, ValueNode forX, long i, NodeView view) {
117         if (i == 0) {
118             return ConstantNode.forIntegerStamp(stamp, 0);
119         } else if (i == 1) {
120             return forX;
121         } else if (i == -1) {
122             return NegateNode.create(forX, view);
123         } else if (i > 0) {
124             if (CodeUtil.isPowerOf2(i)) {
125                 return new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i)));
126             } else if (CodeUtil.isPowerOf2(i - 1)) {
127                 return AddNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i - 1))), forX, view);
128             } else if (CodeUtil.isPowerOf2(i + 1)) {
129                 return SubNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i + 1))), forX, view);
130             } else {
131                 int bitCount = Long.bitCount(i);
132                 long highestBitValue = Long.highestOneBit(i);
133                 if (bitCount == 2) {
134                     // e.g., 0b1000_0010
135                     long lowerBitValue = i - highestBitValue;
136                     assert highestBitValue > 0 && lowerBitValue > 0;
137                     ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(highestBitValue)));
138                     ValueNode right = lowerBitValue == 1 ? forX : new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(lowerBitValue)));
139                     return AddNode.create(left, right, view);
140                 } else {
141                     // e.g., 0b1111_1101
142                     int shiftToRoundUpToPowerOf2 = CodeUtil.log2(highestBitValue) + 1;
143                     long subValue = (1 << shiftToRoundUpToPowerOf2) - i;
144                     if (CodeUtil.isPowerOf2(subValue) && shiftToRoundUpToPowerOf2 < ((IntegerStamp) stamp).getBits()) {
145                         assert CodeUtil.log2(subValue) >= 1;
146                         ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(shiftToRoundUpToPowerOf2));
147                         ValueNode right = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(subValue)));
148                         return SubNode.create(left, right, view);
149                     }
150                 }
151             }
152         } else if (i < 0) {
153             if (CodeUtil.isPowerOf2(-i)) {
154                 return NegateNode.create(LeftShiftNode.create(forX, ConstantNode.forInt(CodeUtil.log2(-i)), view), view);
155             }
156         }
157         return null;
158     }
159 
160     @Override
generate(NodeLIRBuilderTool nodeValueMap, ArithmeticLIRGeneratorTool gen)161     public void generate(NodeLIRBuilderTool nodeValueMap, ArithmeticLIRGeneratorTool gen) {
162         Value op1 = nodeValueMap.operand(getX());
163         Value op2 = nodeValueMap.operand(getY());
164         if (shouldSwapInputs(nodeValueMap)) {
165             Value tmp = op1;
166             op1 = op2;
167             op2 = tmp;
168         }
169         nodeValueMap.setResult(this, gen.emitMul(op1, op2, false));
170     }
171 }
172