1 /*
2  * Copyright (c) 2017, 2019, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @summary Smoke test for JDWP hardening
27  * @library /test/lib
28  * @run driver JdwpAllowTest
29  */
30 
31 import java.io.IOException;
32 
33 import java.net.InetAddress;
34 import java.net.Socket;
35 import java.net.SocketException;
36 
37 import jdk.test.lib.Utils;
38 import jdk.test.lib.apps.LingeredApp;
39 
40 import java.util.ArrayList;
41 import java.util.LinkedList;
42 import java.util.List;
43 import java.util.Random;
44 import java.util.concurrent.TimeUnit;
45 import java.util.regex.Matcher;
46 import java.util.regex.Pattern;
47 
48 
49 public class JdwpAllowTest {
50 
handshake(int port)51     public static int handshake(int port) throws IOException {
52         // Connect to the debuggee and handshake
53         int res = -1;
54         Socket s = null;
55         try {
56             s = new Socket(localAddr, port);
57             s.getOutputStream().write("JDWP-Handshake".getBytes("UTF-8"));
58             byte[] buffer = new byte[24];
59             res = s.getInputStream().read(buffer);
60         }
61         catch (SocketException ex) {
62             ex.printStackTrace();
63             // pass
64         } finally {
65             if (s != null) {
66                 s.close();
67             }
68         }
69         return res;
70     }
71 
prepareCmd(String allowOpt)72     public static ArrayList<String> prepareCmd(String allowOpt) {
73          ArrayList<String> cmd = new ArrayList<>();
74 
75          String jdwpArgs = "-agentlib:jdwp=transport=dt_socket,server=y," +
76                            "suspend=n,address=*:0"
77                             + (allowOpt == null ? "" : ",allow=" + allowOpt);
78          cmd.add(jdwpArgs);
79          return cmd;
80     }
81 
82     private static Pattern listenRegexp = Pattern.compile("Listening for transport \\b(.+)\\b at address: \\b(\\d+)\\b");
detectPort(LingeredApp app)83     private static int detectPort(LingeredApp app) {
84         long maxWaitTime = System.currentTimeMillis()
85                 + Utils.adjustTimeout(10000);  // 10 seconds adjusted for TIMEOUT_FACTOR
86         while (true) {
87             String s = app.getProcessStdout();
88             Matcher m = listenRegexp.matcher(s);
89             if (m.find()) {
90                 // m.group(1) is transport, m.group(2) is port
91                 return Integer.parseInt(m.group(2));
92             }
93             if (System.currentTimeMillis() > maxWaitTime) {
94                 throw new RuntimeException("Could not detect port from '" + s + "' (timeout)");
95             }
96             try {
97                 if (app.getProcess().waitFor(500, TimeUnit.MILLISECONDS)) {
98                     throw new RuntimeException("Could not detect port from '" + s + "' (debuggee is terminated)");
99                 }
100             } catch (InterruptedException e) {
101                 // ignore
102             }
103         }
104     }
105 
positiveTest(String testName, String allowOpt)106     public static void positiveTest(String testName, String allowOpt)
107         throws InterruptedException, IOException {
108         System.err.println("\nStarting " + testName);
109         ArrayList<String> cmd = prepareCmd(allowOpt);
110 
111         LingeredApp a = LingeredApp.startApp(cmd);
112         int res;
113         try {
114             res = handshake(detectPort(a));
115         } finally {
116             a.stopApp();
117         }
118         if (res < 0) {
119             throw new RuntimeException(testName + " FAILED");
120         }
121         System.err.println(testName + " PASSED");
122     }
123 
negativeTest(String testName, String allowOpt)124     public static void negativeTest(String testName, String allowOpt)
125         throws InterruptedException, IOException {
126         System.err.println("\nStarting " + testName);
127         ArrayList<String> cmd = prepareCmd(allowOpt);
128 
129         LingeredApp a = LingeredApp.startApp(cmd);
130         int res;
131         try {
132             res = handshake(detectPort(a));
133         } finally {
134             a.stopApp();
135         }
136         if (res > 0) {
137             System.err.println(testName + ": res=" + res);
138             throw new RuntimeException(testName + " FAILED");
139         }
140         System.err.println(testName + ": returned a negative code as expected: " + res);
141         System.err.println(testName + " PASSED");
142     }
143 
badAllowOptionTest(String testName, String allowOpt)144     public static void badAllowOptionTest(String testName, String allowOpt)
145         throws InterruptedException, IOException {
146         System.err.println("\nStarting " + testName);
147         ArrayList<String> cmd = prepareCmd(allowOpt);
148 
149         LingeredApp a;
150         try {
151             a = LingeredApp.startApp(cmd);
152         } catch (IOException ex) {
153             System.err.println(testName + ": caught expected IOException");
154             System.err.println(testName + " PASSED");
155             return;
156         }
157         // LingeredApp.startApp is expected to fail, but if not, terminate the app
158         a.stopApp();
159         throw new RuntimeException(testName + " FAILED");
160     }
161 
162     /*
163      * Generate allow address by changing random bit in the local address
164      * and calculate 2 masks (prefix length) - one is matches original local address
165      * and another doesn't.
166      */
167     private static class MaskTest {
168         public final String localAddress;
169         public final String allowAddress;
170         public final int prefixLengthGood;
171         public final int prefixLengthBad;
172 
MaskTest(InetAddress addr)173         public MaskTest(InetAddress addr) throws Exception {
174             localAddress = addr.getHostAddress();
175             byte[] bytes = addr.getAddress();
176             Random r = new Random();
177             // prefix length must be >= 1, so bitToChange must be >= 2
178             int bitToChange = r.nextInt(bytes.length * 8 - 3) + 2;
179             setBit(bytes, bitToChange, !getBit(bytes, bitToChange));
180             // clear rest of the bits for mask address
181             for (int i = bitToChange + 1; i < bytes.length * 8; i++) {
182                 setBit(bytes, i, false);
183             }
184             allowAddress = InetAddress.getByAddress(bytes).getHostAddress();
185 
186             prefixLengthBad = bitToChange;
187             prefixLengthGood = bitToChange - 1;
188         }
189 
getBit(byte[] bytes, int pos)190         private static boolean getBit(byte[] bytes, int pos) {
191             return (bytes[pos / 8] & (1 << (7 - (pos % 8)))) != 0;
192         }
193 
setBit(byte[] bytes, int pos, boolean value)194         private static void setBit(byte[] bytes, int pos, boolean value) {
195             byte byteValue = (byte)(1 << (7 - (pos % 8)));
196             if (value) {
197                 bytes[pos / 8] = (byte)(bytes[pos / 8] | byteValue);
198             } else {
199                 bytes[pos / 8] &= (~byteValue);
200             }
201         }
202     }
203 
204     private static String localAddr;
205     private static List<MaskTest> maskTests = new LinkedList<>();
206 
init()207     private static void init() throws Exception {
208         InetAddress addrs[] = InetAddress.getAllByName("localhost");
209         if (addrs.length == 0) {
210             throw new RuntimeException("No addresses is returned for 'localhost'");
211         }
212         localAddr = addrs[0].getHostAddress();
213         System.err.println("localhost address: " + localAddr);
214 
215         for (int i =  0; i < addrs.length; i++) {
216             maskTests.add(new MaskTest(addrs[i]));
217         }
218     }
219 
main(String[] args)220     public static void main(String[] args) throws Exception {
221         init();
222 
223         // No allow option is the same as the allow option ',allow=*' is passed
224         positiveTest("DefaultTest", null);
225 
226         // Explicit permission for connections from everywhere
227         positiveTest("ExplicitDefaultTest", "*");
228 
229         positiveTest("AllowTest", localAddr);
230 
231         positiveTest("MultiAllowTest", localAddr + "+10.0.0.0/8+172.16.0.0/12+192.168.0.0/24");
232 
233         // Bad allow address
234         negativeTest("DenyTest", "0.0.0.0");
235 
236         // Wrong separator ';' is used for allow option
237         badAllowOptionTest("MultiDenyTest", localAddr + ";192.168.0.0/24");
238 
239         // Empty allow option
240         badAllowOptionTest("EmptyAllowOptionTest", "");
241 
242         // Bad mix of allow option '*' with address value
243         badAllowOptionTest("ExplicitMultiDefault1Test", "*+" + localAddr);
244 
245         // Bad mix of allow address value with '*'
246         badAllowOptionTest("ExplicitMultiDefault2Test", localAddr + "+*");
247 
248         for (MaskTest test: maskTests) {
249             // override localAddr (to connect to required IPv4 or IPv6 address)
250             localAddr = test.localAddress;
251             positiveTest("PositiveMaskTest(" + test.localAddress + ")",
252                          test.allowAddress + "/" + test.prefixLengthGood);
253             positiveTest("NegativeMaskTest(" + test.localAddress + ")",
254                          test.allowAddress + "/" + test.prefixLengthBad);
255         }
256 
257         System.err.println("\nTest PASSED");
258     }
259 
260 }
261