1 /*
2  * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
 * @summary Test locating and invoking default/static methods that are defined
 *          in interfaces and/or inherited
28  * @bug 7184826
29  * @build helper.Mod helper.Declared DefaultStaticTestData
30  * @run testng DefaultStaticInvokeTest
31  * @author Yong Lu
32  */
33 
34 import java.lang.invoke.MethodHandle;
35 import java.lang.invoke.MethodHandles;
36 import java.lang.invoke.MethodType;
37 import java.lang.reflect.Method;
38 import java.lang.reflect.Modifier;
39 import java.util.Arrays;
40 import java.util.HashMap;
41 import java.util.HashSet;
42 
43 import static org.testng.Assert.assertEquals;
44 import static org.testng.Assert.assertTrue;
45 import static org.testng.Assert.assertFalse;
46 import static org.testng.Assert.assertNotNull;
47 import static org.testng.Assert.fail;
48 import org.testng.annotations.Test;
49 
50 import static helper.Mod.*;
51 import static helper.Declared.*;
52 import helper.Mod;
53 
54 
55 public class DefaultStaticInvokeTest {
56 
57     // getMethods(): Make sure getMethods returns the expected methods.
58     @Test(dataProvider = "testCasesAll",
59             dataProviderClass = DefaultStaticTestData.class)
testGetMethods(String testTarget, Object param)60     public void testGetMethods(String testTarget, Object param)
61             throws Exception {
62         testMethods(ALL_METHODS, testTarget, param);
63     }
64 
65 
66     // getDeclaredMethods(): Make sure getDeclaredMethods returns the expected methods.
67     @Test(dataProvider = "testCasesAll",
68             dataProviderClass = DefaultStaticTestData.class)
testGetDeclaredMethods(String testTarget, Object param)69     public void testGetDeclaredMethods(String testTarget, Object param)
70             throws Exception {
71         testMethods(DECLARED_ONLY, testTarget, param);
72     }
73 
74 
75     // getMethod(): Make sure that getMethod finds all methods it should find.
76     @Test(dataProvider = "testCasesAll",
77             dataProviderClass = DefaultStaticTestData.class)
testGetMethod(String testTarget, Object param)78     public void testGetMethod(String testTarget, Object param)
79             throws Exception {
80 
81         Class<?> typeUnderTest = Class.forName(testTarget);
82 
83         MethodDesc[] descs = typeUnderTest.getAnnotationsByType(MethodDesc.class);
84 
85         for (MethodDesc desc : descs) {
86             assertTrue(isFoundByGetMethod(typeUnderTest,
87                                           desc.name(),
88                                           argTypes(param)));
89         }
90     }
91 
92 
93     // getMethod(): Make sure that getMethod does *not* find certain methods.
94     @Test(dataProvider = "testCasesAll",
95             dataProviderClass = DefaultStaticTestData.class)
testGetMethodSuperInterfaces(String testTarget, Object param)96     public void testGetMethodSuperInterfaces(String testTarget, Object param)
97             throws Exception {
98 
99         // Make sure static methods in superinterfaces are not found (unless the type under
100         // test declares a static method with the same signature).
101 
102         Class<?> typeUnderTest = Class.forName(testTarget);
103 
104         for (Class<?> interfaze : typeUnderTest.getInterfaces()) {
105 
106             for (MethodDesc desc : interfaze.getAnnotationsByType(MethodDesc.class)) {
107 
108                 boolean isStatic = desc.mod() == STATIC;
109 
110                 boolean declaredInThisType = isMethodDeclared(typeUnderTest,
111                                                               desc.name());
112 
113                 boolean expectedToBeFound = !isStatic || declaredInThisType;
114 
115                 if (expectedToBeFound)
116                     continue; // already tested in testGetMethod()
117 
118                 assertFalse(isFoundByGetMethod(typeUnderTest,
119                                                desc.name(),
120                                                argTypes(param)));
121             }
122         }
123     }
124 
125 
126     // Method.invoke(): Make sure Method.invoke returns the expected value.
127     @Test(dataProvider = "testCasesAll",
128             dataProviderClass = DefaultStaticTestData.class)
testMethodInvoke(String testTarget, Object param)129     public void testMethodInvoke(String testTarget, Object param)
130             throws Exception {
131         Class<?> typeUnderTest = Class.forName(testTarget);
132         MethodDesc[] expectedMethods = typeUnderTest.getAnnotationsByType(MethodDesc.class);
133 
134         // test the method retrieved by Class.getMethod(String, Object[])
135         for (MethodDesc toTest : expectedMethods) {
136             String name = toTest.name();
137             Method m = typeUnderTest.getMethod(name, argTypes(param));
138             testThisMethod(toTest, m, typeUnderTest, param);
139         }
140     }
141 
142 
143     // MethodHandle.invoke(): Make sure MethodHandle.invoke returns the expected value.
144     @Test(dataProvider = "testCasesAll",
145             dataProviderClass = DefaultStaticTestData.class)
testMethodHandleInvoke(String testTarget, Object param)146     public void testMethodHandleInvoke(String testTarget, Object param)
147             throws Throwable {
148         Class<?> typeUnderTest = Class.forName(testTarget);
149         MethodDesc[] expectedMethods = typeUnderTest.getAnnotationsByType(MethodDesc.class);
150 
151         for (MethodDesc toTest : expectedMethods) {
152             String mName = toTest.name();
153             Mod mod = toTest.mod();
154             if (mod != STATIC && typeUnderTest.isInterface()) {
155                 return;
156             }
157 
158             String result = null;
159             String expectedReturn = toTest.retval();
160 
161             MethodHandle methodHandle = getTestMH(typeUnderTest, mName, param);
162             if (mName.equals("staticMethod")) {
163                 result = (param == null)
164                         ? (String) methodHandle.invoke()
165                         : (String) methodHandle.invoke(param);
166             } else {
167                 result = (param == null)
168                         ? (String) methodHandle.invoke(typeUnderTest.newInstance())
169                         : (String) methodHandle.invoke(typeUnderTest.newInstance(), param);
170             }
171 
172             assertEquals(result, expectedReturn);
173         }
174 
175     }
176 
177     // Lookup.findStatic / .findVirtual: Make sure IllegalAccessException is thrown as expected.
178     @Test(dataProvider = "testClasses",
179             dataProviderClass = DefaultStaticTestData.class)
testIAE(String testTarget, Object param)180     public void testIAE(String testTarget, Object param)
181             throws ClassNotFoundException {
182 
183         Class<?> typeUnderTest = Class.forName(testTarget);
184         MethodDesc[] expectedMethods = typeUnderTest.getAnnotationsByType(MethodDesc.class);
185 
186         for (MethodDesc toTest : expectedMethods) {
187             String mName = toTest.name();
188             Mod mod = toTest.mod();
189             if (mod != STATIC && typeUnderTest.isInterface()) {
190                 continue;
191             }
192             Exception caught = null;
193             try {
194                 getTestMH(typeUnderTest, mName, param, true);
195             } catch (Exception e) {
196                 caught = e;
197             }
198             assertNotNull(caught);
199             assertEquals(caught.getClass(), IllegalAccessException.class);
200         }
201     }
202 
203 
204     private static final String[] OBJECT_METHOD_NAMES = {
205         "equals",
206         "hashCode",
207         "getClass",
208         "notify",
209         "notifyAll",
210         "toString",
211         "wait",
212         "wait",
213         "wait",};
214     private static final String LAMBDA_METHOD_NAMES = "lambda$";
215     private static final HashSet<String> OBJECT_NAMES = new HashSet<>(Arrays.asList(OBJECT_METHOD_NAMES));
216     private static final boolean DECLARED_ONLY = true;
217     private static final boolean ALL_METHODS = false;
218 
testMethods(boolean declaredOnly, String testTarget, Object param)219     private void testMethods(boolean declaredOnly, String testTarget, Object param)
220             throws Exception {
221         Class<?> typeUnderTest = Class.forName(testTarget);
222         Method[] methods = declaredOnly
223                 ? typeUnderTest.getDeclaredMethods()
224                 : typeUnderTest.getMethods();
225 
226         MethodDesc[] baseExpectedMethods = typeUnderTest.getAnnotationsByType(MethodDesc.class);
227         MethodDesc[] expectedMethods;
228 
229         // If only declared filter out non-declared from expected result
230         if (declaredOnly) {
231             int nonDeclared = 0;
232             for (MethodDesc desc : baseExpectedMethods) {
233                 if (desc.declared() == NO) {
234                     nonDeclared++;
235                 }
236             }
237             expectedMethods = new MethodDesc[baseExpectedMethods.length - nonDeclared];
238             int i = 0;
239             for (MethodDesc desc : baseExpectedMethods) {
240                 if (desc.declared() == YES) {
241                     expectedMethods[i++] = desc;
242                 }
243             }
244         } else {
245             expectedMethods = baseExpectedMethods;
246         }
247 
248         HashMap<String, Method> myMethods = new HashMap<>(methods.length);
249         for (Method m : methods) {
250             String mName = m.getName();
251             // don't add Object methods and method created from lambda expression
252             if ((!OBJECT_NAMES.contains(mName)) && (!mName.contains(LAMBDA_METHOD_NAMES))) {
253                 myMethods.put(mName, m);
254             }
255         }
256 
257         assertEquals(myMethods.size(), expectedMethods.length);
258 
259         for (MethodDesc toTest : expectedMethods) {
260 
261             String name = toTest.name();
262             Method candidate = myMethods.remove(name);
263 
264             assertNotNull(candidate);
265 
266             testThisMethod(toTest, candidate, typeUnderTest, param);
267 
268         }
269 
270         // Should be no methods left since we remove all we expect to see
271         assertTrue(myMethods.isEmpty());
272     }
273 
274 
testThisMethod(MethodDesc toTest, Method method, Class<?> typeUnderTest, Object param)275     private void testThisMethod(MethodDesc toTest, Method method,
276             Class<?> typeUnderTest, Object param) throws Exception {
277         // Test modifiers, and invoke
278         Mod mod = toTest.mod();
279         String expectedReturn = toTest.retval();
280         switch (mod) {
281             case STATIC:
282                 //assert candidate is static
283                 assertTrue(Modifier.isStatic(method.getModifiers()));
284                 assertFalse(method.isDefault());
285 
286                 // Test invoke it
287                 assertEquals(tryInvoke(method, null, param), expectedReturn);
288                 break;
289             case DEFAULT:
290                 // if typeUnderTest is a class then instantiate and invoke
291                 if (!typeUnderTest.isInterface()) {
292                     assertEquals(tryInvoke(
293                             method,
294                             typeUnderTest,
295                             param),
296                             expectedReturn);
297                 }
298 
299                 //assert candidate is default
300                 assertFalse(Modifier.isStatic(method.getModifiers()));
301                 assertTrue(method.isDefault());
302                 break;
303             case REGULAR:
304                 // if typeUnderTest must be a class
305                 assertEquals(tryInvoke(
306                         method,
307                         typeUnderTest,
308                         param),
309                         expectedReturn);
310 
311                 //assert candidate is neither default nor static
312                 assertFalse(Modifier.isStatic(method.getModifiers()));
313                 assertFalse(method.isDefault());
314                 break;
315             case ABSTRACT:
316                 //assert candidate is neither default nor static
317                 assertFalse(Modifier.isStatic(method.getModifiers()));
318                 assertFalse(method.isDefault());
319                 break;
320             default:
321                 fail(); //this should never happen
322                 break;
323         }
324 
325     }
326 
327 
isMethodDeclared(Class<?> type, String name)328     private boolean isMethodDeclared(Class<?> type, String name) {
329         MethodDesc[] methDescs = type.getAnnotationsByType(MethodDesc.class);
330         for (MethodDesc desc : methDescs) {
331             if (desc.declared() == YES && desc.name().equals(name))
332                 return true;
333         }
334         return false;
335     }
336 
337 
isFoundByGetMethod(Class<?> c, String method, Class<?>... argTypes)338     private boolean isFoundByGetMethod(Class<?> c, String method, Class<?>... argTypes) {
339         try {
340             c.getMethod(method, argTypes);
341             return true;
342         } catch (NoSuchMethodException notFound) {
343             return false;
344         }
345     }
346 
347 
argTypes(Object param)348     private Class<?>[] argTypes(Object param) {
349         return param == null ? new Class[0] : new Class[] { Object.class };
350     }
351 
352 
tryInvoke(Method m, Class<?> receiverType, Object param)353     private Object tryInvoke(Method m, Class<?> receiverType, Object param)
354             throws Exception {
355         Object receiver = receiverType == null ? null : receiverType.newInstance();
356         Object[] args = param == null ? new Object[0] : new Object[] { param };
357         return m.invoke(receiver, args);
358     }
359 
360 
getTestMH(Class clazz, String methodName, Object param)361     private MethodHandle getTestMH(Class clazz, String methodName, Object param)
362             throws Exception {
363         return getTestMH(clazz, methodName, param, false);
364     }
365 
366 
getTestMH(Class clazz, String methodName, Object param, boolean isNegativeTest)367     private MethodHandle getTestMH(Class clazz, String methodName,
368             Object param, boolean isNegativeTest)
369             throws Exception {
370         MethodType mType = (param != null)
371                 ? MethodType.genericMethodType(1)
372                 : MethodType.methodType(String.class);
373         MethodHandles.Lookup lookup = MethodHandles.lookup();
374         if (!isNegativeTest) {
375             return methodName.equals("staticMethod")
376                     ? lookup.findStatic(clazz, methodName, mType)
377                     : lookup.findVirtual(clazz, methodName, mType);
378         } else {
379             return methodName.equals("staticMethod")
380                     ? lookup.findVirtual(clazz, methodName, mType)
381                     : lookup.findStatic(clazz, methodName, mType);
382         }
383     }
384 }
385