1 /*
2  * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 8004970
27  * @summary Lambda serialization in the presence of class loaders
28  * @run main LambdaClassLoaderSerialization
29  * @author Peter Levart
30  */
31 
32 import java.io.ByteArrayInputStream;
33 import java.io.ByteArrayOutputStream;
34 import java.io.IOException;
35 import java.io.InputStream;
36 import java.io.ObjectInputStream;
37 import java.io.ObjectOutputStream;
38 import java.io.Serializable;
39 
/**
 * Regression test: a serializable lambda, and an instance of the class that
 * creates it, must survive a serialization round trip when that class is
 * defined by a custom class loader rather than the application class loader.
 *
 * <p>{@link #main} defines {@link MyCode} (and the other default-package
 * classes it needs) through {@link MyClassLoader}. That copy of {@code MyCode}
 * then serializes and deserializes both itself and a lambda.
 */
public class LambdaClassLoaderSerialization {

    /** A {@link Runnable} that can be written with Java serialization. */
    public interface SerializableRunnable extends Runnable, Serializable {}

    /**
     * The code under test.
     *
     * <p>It is instantiated from the copy of this class defined by
     * {@link MyClassLoader}. Serialization is therefore driven from a class
     * whose defining loader is not the application class loader. See
     * {@link ObjectInputStream#resolveClass} for how deserialization chooses
     * the loader used to resolve classes.
     */
    public static class MyCode implements SerializableRunnable {

        /**
         * Serializes {@code o} with an {@link ObjectOutputStream}.
         *
         * @param o the object to serialize
         * @return the complete serialized stream
         * @throws RuntimeException wrapping any {@link IOException}
         */
        private static byte[] serialize(Object o) {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                oos.writeObject(o);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            // oos has been closed, and therefore flushed, by this point,
            // so baos holds the whole stream.
            return baos.toByteArray();
        }

        /**
         * Deserializes a single object from {@code bytes}.
         *
         * @param bytes a stream produced by {@link #serialize}
         * @param <T> the type the caller expects. The cast is unchecked, so a
         *            mismatch surfaces as a {@link ClassCastException} at the
         *            call site.
         * @return the deserialized object
         * @throws RuntimeException wrapping any {@link IOException} or
         *         {@link ClassNotFoundException}
         */
        private static <T> T deserialize(byte[] bytes) {
            try (ObjectInputStream ois =
                     new ObjectInputStream(new ByteArrayInputStream(bytes))) {
                @SuppressWarnings("unchecked") // the caller chooses T; see Javadoc
                T result = (T) ois.readObject();
                return result;
            } catch (IOException | ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Round-trips {@code this} and a serializable lambda through Java
         * serialization, printing each object along the way.
         *
         * @throws AssertionError if {@code this} deserializes to a class other
         *         than the {@code MyCode} that defined this instance
         * @throws RuntimeException if serialization or deserialization fails
         */
        @Override
        public void run() {
            System.out.println("                this: " + this);

            // Deserialize into Object and compare classes explicitly. A
            // wrong-loader resolution is then reported with a clear message
            // instead of an opaque ClassCastException.
            Object deSerializedThis = deserialize(serialize(this));
            System.out.println("    deSerializedThis: " + deSerializedThis);
            if (deSerializedThis.getClass() != getClass()) {
                throw new AssertionError(
                    "deserialized MyCode was defined by "
                    + deSerializedThis.getClass().getClassLoader()
                    + ", expected " + getClass().getClassLoader());
            }

            SerializableRunnable runnable = () -> {System.out.println("HELLO");};
            System.out.println("            runnable: " + runnable);

            // This assignment casts to the SerializableRunnable visible from
            // MyCode, i.e. the copy defined by MyClassLoader.
            SerializableRunnable deSerializedRunnable = deserialize(serialize(runnable));
            System.out.println("deSerializedRunnable: " + deSerializedRunnable);
            // The rebuilt lambda must still be invocable.
            deSerializedRunnable.run();
        }
    }

    /**
     * Loads {@link MyCode} through a fresh {@link MyClassLoader} and runs it.
     *
     * @param args ignored
     * @throws Exception if {@code MyCode} cannot be loaded or instantiated, or
     *         if the round trip performed by {@link MyCode#run()} fails
     */
    public static void main(String[] args) throws Exception {
        ClassLoader myCl = new MyClassLoader(
            LambdaClassLoaderSerialization.class.getClassLoader()
        );
        Class<?> myCodeClass = Class.forName(
            LambdaClassLoaderSerialization.class.getName() + "$MyCode",
            true,
            myCl
        );
        // Class.newInstance() is deprecated since Java 9; use the no-arg constructor.
        Runnable myCode = (Runnable) myCodeClass.getDeclaredConstructor().newInstance();
        myCode.run();
    }

    /**
     * Class loader that defines every class in the unnamed (default) package
     * itself and delegates all other classes to the usual parent-first
     * lookup.
     *
     * <p>The bytes of a default-package class are read via
     * {@link #getResourceAsStream}. Such classes are never delegated, so this
     * loader defines its own copies of the test's classes.
     */
    static class MyClassLoader extends ClassLoader {

        /**
         * Creates a loader that delegates packaged classes to {@code parent}.
         *
         * @param parent the parent class loader used for delegation
         */
        MyClassLoader(ClassLoader parent) {
            super(parent);
        }

        /**
         * Loads {@code name}. Default-package classes are defined here;
         * everything else goes through {@link ClassLoader#loadClass}.
         */
        @Override
        protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
            if (name.indexOf('.') >= 0) {
                // Packaged classes (java.*, etc.) use normal delegation.
                return super.loadClass(name, resolve);
            }
            synchronized (getClassLoadingLock(name)) {
                Class<?> c = findLoadedClass(name);
                if (c == null) {
                    c = findClass(name);
                }
                if (resolve) {
                    resolveClass(c);
                }
                return c;
            }
        }

        /**
         * Defines {@code name} from the bytes of its {@code .class} resource.
         *
         * @throws ClassNotFoundException if the resource is missing or cannot
         *         be read
         */
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            String path = name.replace('.', '/').concat(".class");
            try (InputStream is = getResourceAsStream(path)) {
                if (is == null) {
                    throw new ClassNotFoundException(name);
                }
                byte[] bytes = is.readAllBytes();
                return defineClass(name, bytes, 0, bytes.length);
            } catch (IOException e) {
                throw new ClassNotFoundException(name, e);
            }
        }
    }
}
140