1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /**
25  * @test
26  * @modules java.management java.base/java.io:+open java.base/java.net:+open
27  * @run main/othervm -Djava.net.preferIPv4Stack=true UnreferencedMulticastSockets
28  * @run main/othervm UnreferencedMulticastSockets
29  * @summary Check that unreferenced multicast sockets are closed
30  */
31 
32 import java.io.FileDescriptor;
33 import java.lang.management.ManagementFactory;
34 import java.lang.management.OperatingSystemMXBean;
35 import java.lang.ref.ReferenceQueue;
36 import java.lang.ref.WeakReference;
37 import java.lang.reflect.Field;
38 import java.io.IOException;
39 import java.net.DatagramPacket;
40 import java.net.DatagramSocket;
41 import java.net.DatagramSocketImpl;
42 import java.net.InetAddress;
43 import java.net.MulticastSocket;
44 import java.net.UnknownHostException;
45 import java.nio.file.Files;
46 import java.nio.file.Path;
47 import java.nio.file.Paths;
48 import java.util.ArrayDeque;
49 import java.util.List;
50 import java.util.Optional;
51 import java.util.concurrent.TimeUnit;
52 
53 import com.sun.management.UnixOperatingSystemMXBean;
54 
/**
 * Checks that {@code MulticastSocket}s which are never closed become eligible
 * for reclamation once unreferenced, together with the {@code FileDescriptor}
 * objects (and their cleanup objects, when present) held by the socket
 * implementation.
 *
 * <p>Weak references to every object of interest are kept in
 * {@link #pendingSockets}. {@code main} drops its strong references, requests
 * GC, and polls {@link #pendingQueue} until all of them have been enqueued or
 * the retry budget is exhausted. Open file descriptor counts and {@code lsof}
 * output are printed for diagnosis only; they are not asserted on.
 */
public class UnreferencedMulticastSockets {

    /** Maximum number of polling rounds while waiting for references to be enqueued. */
    private static final int MAX_ROUNDS = 20;

    /** How long each polling round waits on {@link #pendingQueue}, in milliseconds. */
    private static final long ROUND_TIMEOUT_MILLIS = 1000L;

    /** Implementation class whose instances carry a second descriptor in field {@code fd1}. */
    private static final String TWO_STACKS_IMPL = "java.net.TwoStacksPlainDatagramSocketImpl";

    /**
     * Weak references (sockets, file descriptors, cleanup objects) that have not
     * yet been dequeued; an entry is removed once its reference is dequeued.
     */
    static final ArrayDeque<NamedWeak> pendingSockets = new ArrayDeque<>(5);

    /** Queue on which the tracked {@link NamedWeak} references are enqueued. */
    static final ReferenceQueue<Object> pendingQueue = new ReferenceQueue<>();

    /**
     * Echo server: receives one datagram, adds one to its first byte, sends the
     * packet back to its sender, and then drops its socket reference without
     * closing the socket.
     */
    static class Server implements Runnable {

        /** The server socket; set to {@code null} after the echo has been sent. */
        MulticastSocket ss;

        Server() throws IOException {
            ss = new MulticastSocket(0);
            System.out.printf("  DatagramServer addr: %s: %d%n",
                    this.getHost(), this.getPort());
            pendingSockets.add(new NamedWeak(ss, pendingQueue, "serverMulticastSocket"));
            extractRefs(ss, "serverMulticastSocket");
        }

        /** Returns the address clients use to reach this server ({@code localhost}). */
        InetAddress getHost() throws UnknownHostException {
            return InetAddress.getByName("localhost");
        }

        /** Returns the local port the server socket is bound to. */
        int getPort() {
            return ss.getLocalPort();
        }

        /** Receives one datagram and sends it back with its first byte incremented. */
        public void run() {
            try {
                byte[] buffer = new byte[50];
                DatagramPacket p = new DatagramPacket(buffer, buffer.length);
                ss.receive(p);
                buffer[0] += 1;
                ss.send(p);         // send back +1

                // do NOT close but 'forget' the socket reference
                ss = null;
            } catch (Exception ioe) {
                ioe.printStackTrace();
            }
        }
    }

    /**
     * Runs one echo exchange between a client and a server {@code MulticastSocket},
     * abandons both sockets without closing them, and waits for every tracked weak
     * reference to be enqueued.
     *
     * @param args unused
     * @throws AssertionError if the echoed data is wrong, or if some tracked
     *         references are still pending after {@link #MAX_ROUNDS} rounds
     * @throws Exception on socket I/O failure or interruption
     */
    public static void main(String[] args) throws Exception {

        // Create and close a MulticastSocket up front so that any one-time setup it
        // triggers is not attributed to the sockets under test.
        try (MulticastSocket s = new MulticastSocket(0)) {
            s.getLocalPort();   // no-op; the socket is closed immediately
        }

        long fdCount0 = getFdCount();
        listProcFD();

        // Start the echo server on its own thread.
        Server svr = new Server();
        Thread thr = new Thread(svr);
        thr.start();

        MulticastSocket client = new MulticastSocket(0);
        System.out.printf("  client bound port: %d%n", client.getLocalPort());
        client.connect(svr.getHost(), svr.getPort());
        pendingSockets.add(new NamedWeak(client, pendingQueue, "clientMulticastSocket"));
        extractRefs(client, "clientMulticastSocket");

        byte[] msg = new byte[1];
        msg[0] = 1;
        DatagramPacket p = new DatagramPacket(msg, msg.length, svr.getHost(), svr.getPort());
        client.send(p);

        // msg doubles as the receive buffer for the echoed byte.
        p = new DatagramPacket(msg, msg.length);
        client.receive(p);

        System.out.printf("echo received from: %s%n", p.getSocketAddress());
        if (msg[0] != 2) {
            throw new AssertionError("incorrect data received: expected: 2, actual: " + msg[0]);
        }

        // Do NOT close the MulticastSocket; forget it.
        // Each round waits for one reference to be enqueued. If none arrives, this
        // frame's strong references are dropped and a GC is requested.
        for (int round = 0; round < MAX_ROUNDS && !pendingSockets.isEmpty(); round++) {
            Object ref = pendingQueue.remove(ROUND_TIMEOUT_MILLIS);
            if (ref != null) {
                pendingSockets.remove(ref);
                System.out.printf("  ref freed: %s, remaining: %d%n", ref, pendingSockets.size());
            } else {
                client = null;
                p = null;
                msg = null;
                System.gc();
            }
        }

        thr.join();

        // List the open file descriptors
        long fdCount = getFdCount();
        System.out.printf("Initial fdCount: %d, final fdCount: %d%n", fdCount0, fdCount);
        listProcFD();

        // Judge by what is still pending, not by the round counter. The loop may
        // drain everything on its last permitted round, or may run out of rounds
        // with references still outstanding.
        if (!pendingSockets.isEmpty()) {
            throw new AssertionError("Not all references reclaimed: " + pendingSockets);
        }
    }

    /** Returns the count of open file descriptors, or -1 if not available. */
    private static long getFdCount() {
        OperatingSystemMXBean mxBean = ManagementFactory.getOperatingSystemMXBean();
        return (mxBean instanceof UnixOperatingSystemMXBean)
                ? ((UnixOperatingSystemMXBean) mxBean).getOpenFileDescriptorCount()
                : -1L;
    }

    /**
     * Tracks the {@code FileDescriptor}(s) held by {@code s}'s implementation.
     * They are located reflectively through {@code DatagramSocket.impl} and
     * {@code DatagramSocketImpl.fd}. Field {@code fd1} is also tracked when the
     * implementation class is {@code java.net.TwoStacksPlainDatagramSocketImpl}.
     *
     * @param s    the socket whose implementation is inspected
     * @param name label prefix for the tracked references
     * @throws AssertionError if an expected field is missing or inaccessible
     */
    private static void extractRefs(MulticastSocket s, String name) {
        try {
            Field implField = DatagramSocket.class.getDeclaredField("impl");
            implField.setAccessible(true);
            Object impl = implField.get(s);

            Field fdField = DatagramSocketImpl.class.getDeclaredField("fd");
            fdField.setAccessible(true);
            extractRefs((FileDescriptor) fdField.get(impl), name);

            Class<?> implClass = impl.getClass();
            System.out.printf("socketImplClass: %s%n", implClass);
            if (!implClass.getName().equals(TWO_STACKS_IMPL)) {
                System.out.printf("socketImpl class name not matched: %s != %s%n",
                        implClass.getName(), TWO_STACKS_IMPL);
                return;
            }
            Field fd1Field = implClass.getDeclaredField("fd1");
            fd1Field.setAccessible(true);
            extractRefs((FileDescriptor) fd1Field.get(impl), name + "::twoStacksFd1");
        } catch (NoSuchFieldException | IllegalAccessException ex) {
            ex.printStackTrace();
            throw new AssertionError("missing field", ex);
        }
    }

    /**
     * Tracks {@code fileDescriptor} and, if non-null, its {@code cleanup} object,
     * then prints what was found. A {@code null} descriptor is only printed.
     *
     * @param fileDescriptor descriptor to track; may be {@code null}
     * @param name           label prefix for the tracked references
     * @throws AssertionError if an expected field is missing or inaccessible
     */
    private static void extractRefs(FileDescriptor fileDescriptor, String name) {
        Object cleanup = null;
        int rawfd = -1;
        try {
            if (fileDescriptor != null) {
                Field fdField = FileDescriptor.class.getDeclaredField("fd");
                fdField.setAccessible(true);
                rawfd = fdField.getInt(fileDescriptor);

                Field cleanupField = FileDescriptor.class.getDeclaredField("cleanup");
                cleanupField.setAccessible(true);
                cleanup = cleanupField.get(fileDescriptor);

                pendingSockets.add(new NamedWeak(fileDescriptor, pendingQueue,
                        name + "::fileDescriptor: " + rawfd));
                // A weak reference to null is never enqueued. Tracking a missing
                // cleanup object would therefore make the wait in main() fail
                // spuriously.
                if (cleanup != null) {
                    pendingSockets.add(new NamedWeak(cleanup, pendingQueue,
                            name + "::fdCleanup: " + rawfd));
                }
            }
        } catch (NoSuchFieldException | IllegalAccessException ex) {
            ex.printStackTrace();
            throw new AssertionError("missing field", ex);
        } finally {
            System.out.printf("  %s:: fd: %s, fd: %d, cleanup: %s%n",
                    name, fileDescriptor, rawfd, cleanup);
        }
    }

    /**
     * Lists this process's open file descriptors via {@code lsof}, if it is
     * installed in {@code /usr/bin} or {@code /usr/sbin}. Failures are printed,
     * not thrown.
     */
    static void listProcFD() {
        List<String> lsofDirs = List.of("/usr/bin", "/usr/sbin");
        Optional<Path> lsof = lsofDirs.stream()
                .map(dir -> Paths.get(dir, "lsof"))
                .filter(Files::isExecutable)
                .findFirst();
        lsof.ifPresent(exe -> {
            try {
                System.out.printf("Open File Descriptors:%n");
                long pid = ProcessHandle.current().pid();
                ProcessBuilder pb = new ProcessBuilder(exe.toString(), "-p", Long.toString(pid));
                pb.inheritIO();
                Process p = pb.start();
                if (!p.waitFor(10, TimeUnit.SECONDS)) {
                    // Do not leave a hung lsof running behind the test.
                    p.destroyForcibly();
                }
            } catch (IOException ioe) {
                ioe.printStackTrace();
            } catch (InterruptedException ie) {
                ie.printStackTrace();
                Thread.currentThread().interrupt();
            }
        });
    }


    /** Weak reference with a label, so the log shows which object was reclaimed. */
    static class NamedWeak extends WeakReference<Object> {
        private final String name;

        NamedWeak(Object o, ReferenceQueue<Object> queue, String name) {
            super(o, queue);
            this.name = name;
        }

        @Override
        public String toString() {
            return name + "; " + super.toString();
        }
    }
}
273