1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 import sun.security.util.HexDumpEncoder;
25 
26 import java.io.IOException;
27 import java.net.DatagramPacket;
28 import java.net.DatagramSocket;
29 import java.net.InetAddress;
30 import java.net.SocketException;
31 import java.nio.ByteBuffer;
32 import java.nio.file.Paths;
33 import java.util.ArrayList;
34 import java.util.Arrays;
35 import java.util.List;
36 import java.util.Scanner;
37 import java.util.regex.MatchResult;
38 
39 /*
40  * A dummy DNS server.
41  *
42  * Loads a sequence of DNS messages from a capture file into its cache.
43  * It listens for DNS UDP requests, finds match request in cache and sends the
44  * corresponding DNS responses.
45  *
46  * The capture file contains an DNS protocol exchange in the hexadecimal
47  * dump format emitted by HexDumpEncoder:
48  *
49  * xxxx: 00 11 22 33 44 55 66 77   88 99 aa bb cc dd ee ff  ................
50  *
51  * Typically, DNS protocol exchange is generated by DNSTracer who captures
52  * communication messages between DNS application program and real DNS server
53  */
54 public class DNSServer extends Thread implements Server {
55 
56     public class Pair<F, S> {
57         private F first;
58         private S second;
59 
Pair(F first, S second)60         public Pair(F first, S second) {
61             this.first = first;
62             this.second = second;
63         }
64 
setFirst(F first)65         public void setFirst(F first) {
66             this.first = first;
67         }
68 
setSecond(S second)69         public void setSecond(S second) {
70             this.second = second;
71         }
72 
getFirst()73         public F getFirst() {
74             return first;
75         }
76 
getSecond()77         public S getSecond() {
78             return second;
79         }
80     }
81 
82     public static final int DNS_HEADER_SIZE = 12;
83     public static final int DNS_PACKET_SIZE = 512;
84 
85     static HexDumpEncoder encoder = new HexDumpEncoder();
86 
87     private DatagramSocket socket;
88     private String filename;
89     private boolean loop;
90     private final List<Pair<byte[], byte[]>> cache = new ArrayList<>();
91     private ByteBuffer reqBuffer = ByteBuffer.allocate(DNS_PACKET_SIZE);
92     private volatile boolean isRunning;
93 
DNSServer(String filename)94     public DNSServer(String filename) throws SocketException {
95         this(filename, false);
96     }
97 
DNSServer(String filename, boolean loop)98     public DNSServer(String filename, boolean loop) throws SocketException {
99         this.socket = new DatagramSocket(0, InetAddress.getLoopbackAddress());
100         this.filename = filename;
101         this.loop = loop;
102     }
103 
run()104     public void run() {
105         try {
106             isRunning = true;
107             System.out.println(
108                     "DNSServer: Loading DNS cache data from : " + filename);
109             loadCaptureFile(filename);
110 
111             System.out.println(
112                     "DNSServer: listening on port " + socket.getLocalPort());
113 
114             System.out.println("DNSServer: loop playback: " + loop);
115 
116             int playbackIndex = 0;
117 
118             while (playbackIndex < cache.size()) {
119                 DatagramPacket reqPacket = receiveQuery();
120 
121                 if (!verifyRequestMsg(reqPacket, playbackIndex)) {
122                     if (playbackIndex > 0 && verifyRequestMsg(reqPacket,
123                             playbackIndex - 1)) {
124                         System.out.println(
125                                 "DNSServer: received retry query, resend");
126                         playbackIndex--;
127                     } else {
128                         throw new RuntimeException(
129                                 "DNSServer: Error: Failed to verify DNS request. "
130                                         + "Not identical request message : \n"
131                                         + encoder.encodeBuffer(
132                                         Arrays.copyOf(reqPacket.getData(),
133                                                 reqPacket.getLength())));
134                     }
135                 }
136 
137                 sendResponse(reqPacket, playbackIndex);
138 
139                 playbackIndex++;
140                 if (loop && playbackIndex >= cache.size()) {
141                     playbackIndex = 0;
142                 }
143             }
144 
145             System.out.println(
146                     "DNSServer: Done for all cached messages playback");
147 
148             System.out.println(
149                     "DNSServer: Still listening for possible retry query");
150             while (true) {
151                 DatagramPacket reqPacket = receiveQuery();
152 
153                 // here we only handle the retry query for last one
154                 if (!verifyRequestMsg(reqPacket, playbackIndex - 1)) {
155                     throw new RuntimeException(
156                             "DNSServer: Error: Failed to verify DNS request. "
157                                     + "Not identical request message : \n"
158                                     + encoder.encodeBuffer(
159                                     Arrays.copyOf(reqPacket.getData(),
160                                             reqPacket.getLength())));
161                 }
162 
163                 sendResponse(reqPacket, playbackIndex - 1);
164             }
165         } catch (Exception e) {
166             if (isRunning) {
167                 System.err.println("DNSServer: Error: " + e);
168                 e.printStackTrace();
169             } else {
170                 System.out.println("DNSServer: Exit");
171             }
172         }
173     }
174 
receiveQuery()175     private DatagramPacket receiveQuery() throws IOException {
176         DatagramPacket reqPacket = new DatagramPacket(reqBuffer.array(),
177                 reqBuffer.array().length);
178         socket.receive(reqPacket);
179 
180         System.out.println("DNSServer: received query message from " + reqPacket
181                 .getSocketAddress());
182 
183         return reqPacket;
184     }
185 
sendResponse(DatagramPacket reqPacket, int playbackIndex)186     private void sendResponse(DatagramPacket reqPacket, int playbackIndex)
187             throws IOException {
188         byte[] payload = generateResponsePayload(reqPacket, playbackIndex);
189         socket.send(new DatagramPacket(payload, payload.length,
190                 reqPacket.getSocketAddress()));
191         System.out.println("DNSServer: send response message to " + reqPacket
192                 .getSocketAddress());
193     }
194 
195     /*
196      * Load a capture file containing an DNS protocol exchange in the
197      * hexadecimal dump format emitted by sun.misc.HexDumpEncoder:
198      *
199      * xxxx: 00 11 22 33 44 55 66 77   88 99 aa bb cc dd ee ff  ................
200      */
loadCaptureFile(String filename)201     private void loadCaptureFile(String filename) throws IOException {
202         StringBuilder hexString = new StringBuilder();
203         String pattern = "(....): (..) (..) (..) (..) (..) (..) (..) (..)   "
204                 + "(..) (..) (..) (..) (..) (..) (..) (..).*";
205 
206         try (Scanner fileScanner = new Scanner(Paths.get(filename))) {
207             while (fileScanner.hasNextLine()) {
208 
209                 try (Scanner lineScanner = new Scanner(
210                         fileScanner.nextLine())) {
211                     if (lineScanner.findInLine(pattern) == null) {
212                         continue;
213                     }
214                     MatchResult result = lineScanner.match();
215                     for (int i = 1; i <= result.groupCount(); i++) {
216                         String digits = result.group(i);
217                         if (digits.length() == 4) {
218                             if (digits.equals("0000")) { // start-of-message
219                                 if (hexString.length() > 0) {
220                                     addToCache(hexString.toString());
221                                     hexString.delete(0, hexString.length());
222                                 }
223                             }
224                             continue;
225                         } else if (digits.equals("  ")) { // short message
226                             continue;
227                         }
228                         hexString.append(digits);
229                     }
230                 }
231             }
232         }
233         addToCache(hexString.toString());
234     }
235 
236     /*
237      * Add an DNS encoding to the cache (by request message key).
238      */
addToCache(String hexString)239     private void addToCache(String hexString) {
240         byte[] encoding = parseHexBinary(hexString);
241         if (encoding.length < DNS_HEADER_SIZE) {
242             throw new RuntimeException("Invalid DNS message : " + hexString);
243         }
244 
245         if (getQR(encoding) == 0) {
246             // a query message, create entry in cache
247             cache.add(new Pair<>(encoding, null));
248             System.out.println(
249                     "    adding DNS query message with ID " + getID(encoding)
250                             + " to the cache");
251         } else {
252             // a response message, attach it to the query entry
253             if (!cache.isEmpty() && (getID(getLatestCacheEntry().getFirst())
254                     == getID(encoding))) {
255                 getLatestCacheEntry().setSecond(encoding);
256                 System.out.println(
257                         "    adding DNS response message associated to ID "
258                                 + getID(encoding) + " in the cache");
259             } else {
260                 throw new RuntimeException(
261                         "Invalid DNS message : " + hexString);
262             }
263         }
264     }
265 
266     /*
267      * ID: A 16 bit identifier assigned by the program that generates any
268      * kind of query. This identifier is copied the corresponding reply and
269      * can be used by the requester to match up replies to outstanding queries.
270      */
getID(byte[] encoding)271     private static int getID(byte[] encoding) {
272         return ByteBuffer.wrap(encoding, 0, 2).getShort();
273     }
274 
275     /*
276      * QR: A one bit field that specifies whether this message is
277      * a query (0), or a response (1) after ID
278      */
getQR(byte[] encoding)279     private static int getQR(byte[] encoding) {
280         return encoding[2] & (0x01 << 7);
281     }
282 
getLatestCacheEntry()283     private Pair<byte[], byte[]> getLatestCacheEntry() {
284         return cache.get(cache.size() - 1);
285     }
286 
verifyRequestMsg(DatagramPacket packet, int playbackIndex)287     private boolean verifyRequestMsg(DatagramPacket packet, int playbackIndex) {
288         byte[] cachedRequest = cache.get(playbackIndex).getFirst();
289         return Arrays.equals(Arrays
290                         .copyOfRange(packet.getData(), 2, packet.getLength()),
291                 Arrays.copyOfRange(cachedRequest, 2, cachedRequest.length));
292     }
293 
generateResponsePayload(DatagramPacket packet, int playbackIndex)294     private byte[] generateResponsePayload(DatagramPacket packet,
295             int playbackIndex) {
296         byte[] resMsg = cache.get(playbackIndex).getSecond();
297         byte[] payload = Arrays.copyOf(resMsg, resMsg.length);
298 
299         // replace the ID with same with real request
300         payload[0] = packet.getData()[0];
301         payload[1] = packet.getData()[1];
302 
303         return payload;
304     }
305 
parseHexBinary(String s)306     public static byte[] parseHexBinary(String s) {
307 
308         final int len = s.length();
309 
310         // "111" is not a valid hex encoding.
311         if (len % 2 != 0) {
312             throw new IllegalArgumentException(
313                     "hexBinary needs to be even-length: " + s);
314         }
315 
316         byte[] out = new byte[len / 2];
317 
318         for (int i = 0; i < len; i += 2) {
319             int h = hexToBin(s.charAt(i));
320             int l = hexToBin(s.charAt(i + 1));
321             if (h == -1 || l == -1) {
322                 throw new IllegalArgumentException(
323                         "contains illegal character for hexBinary: " + s);
324             }
325 
326             out[i / 2] = (byte) (h * 16 + l);
327         }
328 
329         return out;
330     }
331 
hexToBin(char ch)332     private static int hexToBin(char ch) {
333         if ('0' <= ch && ch <= '9') {
334             return ch - '0';
335         }
336         if ('A' <= ch && ch <= 'F') {
337             return ch - 'A' + 10;
338         }
339         if ('a' <= ch && ch <= 'f') {
340             return ch - 'a' + 10;
341         }
342         return -1;
343     }
344 
stopServer()345     @Override public void stopServer() {
346         isRunning = false;
347         if (socket != null) {
348             try {
349                 socket.close();
350             } catch (Exception e) {
351                 // ignore
352             }
353         }
354     }
355 
getPort()356     @Override public int getPort() {
357         if (socket != null) {
358             return socket.getLocalPort();
359         } else {
360             return -1;
361         }
362     }
363 }
364