1 /*
2  * Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 
25 package org.graalvm.compiler.replacements.test;
26 
27 import static org.graalvm.compiler.nodeinfo.InputType.Guard;
28 import static org.graalvm.compiler.nodeinfo.InputType.Memory;
29 import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_IGNORED;
30 import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_IGNORED;
31 import static org.hamcrest.CoreMatchers.instanceOf;
32 
33 import org.graalvm.compiler.api.replacements.ClassSubstitution;
34 import org.graalvm.compiler.api.replacements.MethodSubstitution;
35 import org.graalvm.compiler.core.common.type.StampFactory;
36 import org.graalvm.compiler.graph.NodeClass;
37 import org.graalvm.compiler.graph.iterators.NodeIterable;
38 import org.graalvm.compiler.nodeinfo.NodeInfo;
39 import org.graalvm.compiler.nodeinfo.StructuralInput.Guard;
40 import org.graalvm.compiler.nodeinfo.StructuralInput.Memory;
41 import org.graalvm.compiler.nodes.ConstantNode;
42 import org.graalvm.compiler.nodes.FixedWithNextNode;
43 import org.graalvm.compiler.nodes.ReturnNode;
44 import org.graalvm.compiler.nodes.StructuredGraph;
45 import org.graalvm.compiler.nodes.ValueNode;
46 import org.graalvm.compiler.nodes.calc.FloatingNode;
47 import org.graalvm.compiler.nodes.extended.GuardingNode;
48 import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins;
49 import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins.Registration;
50 import org.graalvm.compiler.nodes.memory.MemoryNode;
51 import org.graalvm.compiler.replacements.classfile.ClassfileBytecodeProvider;
52 import org.junit.Assert;
53 import org.junit.Test;
54 
55 import jdk.vm.ci.meta.JavaKind;
56 
57 public class SubstitutionsTest extends ReplacementsTest {
58 
59     @NodeInfo(allowedUsageTypes = {Memory}, cycles = CYCLES_IGNORED, size = SIZE_IGNORED)
60     static class TestMemory extends FixedWithNextNode implements MemoryNode {
61         private static final NodeClass<TestMemory> TYPE = NodeClass.create(TestMemory.class);
62 
TestMemory()63         protected TestMemory() {
64             super(TYPE, StampFactory.forVoid());
65         }
66 
67         @NodeIntrinsic
memory()68         public static native Memory memory();
69     }
70 
71     @NodeInfo(allowedUsageTypes = {Guard}, cycles = CYCLES_IGNORED, size = SIZE_IGNORED)
72     static class TestGuard extends FloatingNode implements GuardingNode {
73         private static final NodeClass<TestGuard> TYPE = NodeClass.create(TestGuard.class);
74 
75         @Input(Memory) MemoryNode memory;
76 
TestGuard(ValueNode memory)77         protected TestGuard(ValueNode memory) {
78             super(TYPE, StampFactory.forVoid());
79             this.memory = (MemoryNode) memory;
80         }
81 
82         @NodeIntrinsic
guard(Memory memory)83         public static native Guard guard(Memory memory);
84     }
85 
86     @NodeInfo(cycles = CYCLES_IGNORED, size = SIZE_IGNORED)
87     static class TestValue extends FloatingNode {
88         private static final NodeClass<TestValue> TYPE = NodeClass.create(TestValue.class);
89 
90         @Input(Guard) GuardingNode guard;
91 
TestValue(ValueNode guard)92         protected TestValue(ValueNode guard) {
93             super(TYPE, StampFactory.forKind(JavaKind.Int));
94             this.guard = (GuardingNode) guard;
95         }
96 
97         @NodeIntrinsic
value(Guard guard)98         public static native int value(Guard guard);
99     }
100 
101     private static class TestMethod {
102 
test()103         public static int test() {
104             return 42;
105         }
106     }
107 
108     @ClassSubstitution(TestMethod.class)
109     private static class TestMethodSubstitution {
110 
111         @MethodSubstitution
test()112         public static int test() {
113             Memory memory = TestMemory.memory();
114             Guard guard = TestGuard.guard(memory);
115             return TestValue.value(guard);
116         }
117     }
118 
119     @Override
registerInvocationPlugins(InvocationPlugins invocationPlugins)120     protected void registerInvocationPlugins(InvocationPlugins invocationPlugins) {
121         new PluginFactory_SubstitutionsTest().registerPlugins(invocationPlugins, null);
122         ClassfileBytecodeProvider bytecodeProvider = getSystemClassLoaderBytecodeProvider();
123         Registration r = new Registration(invocationPlugins, TestMethod.class, getReplacements(), bytecodeProvider);
124         r.registerMethodSubstitution(TestMethodSubstitution.class, "test");
125         super.registerInvocationPlugins(invocationPlugins);
126     }
127 
callTest()128     public static int callTest() {
129         return TestMethod.test();
130     }
131 
132     @Override
checkHighTierGraph(StructuredGraph graph)133     protected void checkHighTierGraph(StructuredGraph graph) {
134         // Check that the graph contains the expected test nodes.
135         NodeIterable<ReturnNode> retNodes = graph.getNodes().filter(ReturnNode.class);
136         Assert.assertTrue("expected exactly one ReturnNode", retNodes.count() == 1);
137         ReturnNode ret = retNodes.first();
138 
139         Assert.assertThat(ret.result(), instanceOf(TestValue.class));
140         TestValue value = (TestValue) ret.result();
141 
142         Assert.assertThat(value.guard, instanceOf(TestGuard.class));
143         TestGuard guard = (TestGuard) value.guard;
144 
145         Assert.assertThat(guard.memory, instanceOf(TestMemory.class));
146         TestMemory memory = (TestMemory) guard.memory;
147 
148         // Remove the test nodes, replacing them by the constant 42.
149         // This implicitly makes sure that the rest of the graph is valid.
150         ret.replaceFirstInput(value, graph.unique(ConstantNode.forInt(42)));
151         value.safeDelete();
152         guard.safeDelete();
153         graph.removeFixed(memory);
154     }
155 
156     @Test
snippetTest()157     public void snippetTest() {
158         test("callTest");
159     }
160 }
161