1 /*
2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  *  This code is free software; you can redistribute it and/or modify it
6  *  under the terms of the GNU General Public License version 2 only, as
7  *  published by the Free Software Foundation.
8  *
9  *  This code is distributed in the hope that it will be useful, but WITHOUT
10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  *  version 2 for more details (a copy is included in the LICENSE file that
13  *  accompanied this code).
14  *
15  *  You should have received a copy of the GNU General Public License version
16  *  2 along with this work; if not, write to the Free Software Foundation,
17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  *  or visit www.oracle.com if you need additional information or have any
21  *  questions.
22  *
23  */
24 
25 /*
26  * @test
27  * @requires ((os.arch == "amd64" | os.arch == "x86_64") & sun.arch.data.model == "64") | os.arch == "aarch64"
28  * @modules jdk.incubator.foreign/jdk.internal.foreign
29  * @build NativeTestHelper CallGeneratorHelper TestUpcall
30  *
31  * @run testng/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-VerifyDependencies
32  *   --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17
33  *   TestUpcall
34  */
35 
36 import jdk.incubator.foreign.CLinker;
37 import jdk.incubator.foreign.FunctionDescriptor;
38 import jdk.incubator.foreign.SymbolLookup;
39 import jdk.incubator.foreign.MemoryAddress;
40 import jdk.incubator.foreign.MemoryLayout;
41 import jdk.incubator.foreign.MemorySegment;
42 
43 import jdk.incubator.foreign.ResourceScope;
44 import org.testng.annotations.BeforeClass;
45 import org.testng.annotations.Test;
46 
47 import java.lang.invoke.MethodHandle;
48 import java.lang.invoke.MethodHandles;
49 import java.lang.invoke.MethodType;
50 import java.util.ArrayList;
51 import java.util.List;
52 import java.util.concurrent.atomic.AtomicReference;
53 import java.util.function.Consumer;
54 import java.util.stream.Collectors;
55 
56 import static java.lang.invoke.MethodHandles.insertArguments;
57 import static jdk.incubator.foreign.CLinker.C_POINTER;
58 import static org.testng.Assert.assertEquals;
59 
60 
61 public class TestUpcall extends CallGeneratorHelper {
62 
63     static {
64         System.loadLibrary("TestUpcall");
65     }
66     static CLinker abi = CLinker.getInstance();
67 
68     static final SymbolLookup LOOKUP = SymbolLookup.loaderLookup();
69 
70     static MethodHandle DUMMY;
71     static MethodHandle PASS_AND_SAVE;
72 
73     static {
74         try {
75             DUMMY = MethodHandles.lookup().findStatic(TestUpcall.class, "dummy", MethodType.methodType(void.class));
76             PASS_AND_SAVE = MethodHandles.lookup().findStatic(TestUpcall.class, "passAndSave",
77                     MethodType.methodType(Object.class, Object[].class, AtomicReference.class));
78         } catch (Throwable ex) {
79             throw new IllegalStateException(ex);
80         }
81     }
82 
83     static MemoryAddress dummyStub;
84 
85     @BeforeClass
setup()86     void setup() {
87         dummyStub = abi.upcallStub(DUMMY, FunctionDescriptor.ofVoid(), ResourceScope.newImplicitScope());
88     }
89 
90     @Test(dataProvider="functions", dataProviderClass=CallGeneratorHelper.class)
testUpcalls(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields)91     public void testUpcalls(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
92         List<Consumer<Object>> returnChecks = new ArrayList<>();
93         List<Consumer<Object[]>> argChecks = new ArrayList<>();
94         MemoryAddress addr = LOOKUP.lookup(fName).get();
95         MethodType mtype = methodType(ret, paramTypes, fields);
96         try (NativeScope scope = new NativeScope()) {
97             MethodHandle mh = abi.downcallHandle(addr, scope, mtype, function(ret, paramTypes, fields));
98             Object[] args = makeArgs(scope.scope(), ret, paramTypes, fields, returnChecks, argChecks);
99             Object[] callArgs = args;
100             Object res = mh.invokeWithArguments(callArgs);
101             argChecks.forEach(c -> c.accept(args));
102             if (ret == Ret.NON_VOID) {
103                 returnChecks.forEach(c -> c.accept(res));
104             }
105         }
106     }
107 
108     @Test(dataProvider="functions", dataProviderClass=CallGeneratorHelper.class)
testUpcallsNoScope(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields)109     public void testUpcallsNoScope(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
110         List<Consumer<Object>> returnChecks = new ArrayList<>();
111         List<Consumer<Object[]>> argChecks = new ArrayList<>();
112         MemoryAddress addr = LOOKUP.lookup(fName).get();
113         MethodType mtype = methodType(ret, paramTypes, fields);
114         MethodHandle mh = abi.downcallHandle(addr, IMPLICIT_ALLOCATOR, mtype, function(ret, paramTypes, fields));
115         Object[] args = makeArgs(ResourceScope.newImplicitScope(), ret, paramTypes, fields, returnChecks, argChecks);
116         Object[] callArgs = args;
117         Object res = mh.invokeWithArguments(callArgs);
118         argChecks.forEach(c -> c.accept(args));
119         if (ret == Ret.NON_VOID) {
120             returnChecks.forEach(c -> c.accept(res));
121         }
122     }
123 
methodType(Ret ret, List<ParamType> params, List<StructFieldType> fields)124     static MethodType methodType(Ret ret, List<ParamType> params, List<StructFieldType> fields) {
125         MethodType mt = ret == Ret.VOID ?
126                 MethodType.methodType(void.class) : MethodType.methodType(paramCarrier(params.get(0).layout(fields)));
127         for (ParamType p : params) {
128             mt = mt.appendParameterTypes(paramCarrier(p.layout(fields)));
129         }
130         mt = mt.appendParameterTypes(MemoryAddress.class); //the callback
131         return mt;
132     }
133 
function(Ret ret, List<ParamType> params, List<StructFieldType> fields)134     static FunctionDescriptor function(Ret ret, List<ParamType> params, List<StructFieldType> fields) {
135         List<MemoryLayout> paramLayouts = params.stream().map(p -> p.layout(fields)).collect(Collectors.toList());
136         paramLayouts.add(C_POINTER); // the callback
137         MemoryLayout[] layouts = paramLayouts.toArray(new MemoryLayout[0]);
138         return ret == Ret.VOID ?
139                 FunctionDescriptor.ofVoid(layouts) :
140                 FunctionDescriptor.of(layouts[0], layouts);
141     }
142 
makeArgs(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks)143     static Object[] makeArgs(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks) throws ReflectiveOperationException {
144         Object[] args = new Object[params.size() + 1];
145         for (int i = 0 ; i < params.size() ; i++) {
146             args[i] = makeArg(params.get(i).layout(fields), checks, i == 0);
147         }
148         args[params.size()] = makeCallback(scope, ret, params, fields, checks, argChecks);
149         return args;
150     }
151 
152     @SuppressWarnings("unchecked")
makeCallback(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks)153     static MemoryAddress makeCallback(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks) {
154         if (params.isEmpty()) {
155             return dummyStub.address();
156         }
157 
158         AtomicReference<Object[]> box = new AtomicReference<>();
159         MethodHandle mh = insertArguments(PASS_AND_SAVE, 1, box);
160         mh = mh.asCollector(Object[].class, params.size());
161 
162         for (int i = 0; i < params.size(); i++) {
163             ParamType pt = params.get(i);
164             MemoryLayout layout = pt.layout(fields);
165             Class<?> carrier = paramCarrier(layout);
166             mh = mh.asType(mh.type().changeParameterType(i, carrier));
167 
168             final int finalI = i;
169             if (carrier == MemorySegment.class) {
170                 argChecks.add(o -> assertStructEquals((MemorySegment) box.get()[finalI], (MemorySegment) o[finalI], layout));
171             } else {
172                 argChecks.add(o -> assertEquals(box.get()[finalI], o[finalI]));
173             }
174         }
175 
176         ParamType firstParam = params.get(0);
177         MemoryLayout firstlayout = firstParam.layout(fields);
178         Class<?> firstCarrier = paramCarrier(firstlayout);
179 
180         if (firstCarrier == MemorySegment.class) {
181             checks.add(o -> assertStructEquals((MemorySegment) box.get()[0], (MemorySegment) o, firstlayout));
182         } else {
183             checks.add(o -> assertEquals(o, box.get()[0]));
184         }
185 
186         mh = mh.asType(mh.type().changeReturnType(ret == Ret.VOID ? void.class : firstCarrier));
187 
188         MemoryLayout[] paramLayouts = params.stream().map(p -> p.layout(fields)).toArray(MemoryLayout[]::new);
189         FunctionDescriptor func = ret != Ret.VOID
190                 ? FunctionDescriptor.of(firstlayout, paramLayouts)
191                 : FunctionDescriptor.ofVoid(paramLayouts);
192         return abi.upcallStub(mh, func, scope);
193     }
194 
passAndSave(Object[] o, AtomicReference<Object[]> ref)195     static Object passAndSave(Object[] o, AtomicReference<Object[]> ref) {
196         for (int i = 0; i < o.length; i++) {
197             if (o[i] instanceof MemorySegment) {
198                 MemorySegment ms = (MemorySegment) o[i];
199                 MemorySegment copy = MemorySegment.allocateNative(ms.byteSize(), ResourceScope.newImplicitScope());
200                 copy.copyFrom(ms);
201                 o[i] = copy;
202             }
203         }
204         ref.set(o);
205         return o[0];
206     }
207 
dummy()208     static void dummy() {
209         //do nothing
210     }
211 }
212