1 /*
2  * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package sun.security.ssl;
27 
28 import java.io.*;
29 import java.nio.*;
30 import java.util.*;
31 import javax.net.ssl.*;
32 import sun.security.ssl.SSLCipher.SSLWriteCipher;
33 
34 /**
35  * DTLS {@code OutputRecord} implementation for {@code SSLEngine}.
36  */
37 final class DTLSOutputRecord extends OutputRecord implements DTLSRecord {
38 
39     private DTLSFragmenter fragmenter = null;
40 
41     int                 writeEpoch;
42 
43     int                 prevWriteEpoch;
44     Authenticator       prevWriteAuthenticator;
45     SSLWriteCipher      prevWriteCipher;
46 
47     private volatile boolean isCloseWaiting = false;
48 
DTLSOutputRecord(HandshakeHash handshakeHash)49     DTLSOutputRecord(HandshakeHash handshakeHash) {
50         super(handshakeHash, SSLWriteCipher.nullDTlsWriteCipher());
51 
52         this.writeEpoch = 0;
53         this.prevWriteEpoch = 0;
54         this.prevWriteCipher = SSLWriteCipher.nullDTlsWriteCipher();
55 
56         this.packetSize = DTLSRecord.maxRecordSize;
57         this.protocolVersion = ProtocolVersion.NONE;
58     }
59 
60     @Override
close()61     public synchronized void close() throws IOException {
62         if (!isClosed) {
63             if (fragmenter != null && fragmenter.hasAlert()) {
64                 isCloseWaiting = true;
65             } else {
66                 super.close();
67             }
68         }
69     }
70 
isClosed()71     boolean isClosed() {
72         return isClosed || isCloseWaiting;
73     }
74 
75     @Override
initHandshaker()76     void initHandshaker() {
77         // clean up
78         fragmenter = null;
79     }
80 
81     @Override
finishHandshake()82     void finishHandshake() {
83         // Nothing to do here currently.
84     }
85 
86     @Override
changeWriteCiphers(SSLWriteCipher writeCipher, boolean useChangeCipherSpec)87     void changeWriteCiphers(SSLWriteCipher writeCipher,
88             boolean useChangeCipherSpec) throws IOException {
89         if (isClosed()) {
90             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
91                 SSLLogger.warning("outbound has closed, ignore outbound " +
92                     "change_cipher_spec message");
93             }
94             return;
95         }
96 
97         if (useChangeCipherSpec) {
98             encodeChangeCipherSpec();
99         }
100 
101         prevWriteCipher.dispose();
102 
103         this.prevWriteCipher = this.writeCipher;
104         this.prevWriteEpoch = this.writeEpoch;
105 
106         this.writeCipher = writeCipher;
107         this.writeEpoch++;
108 
109         this.isFirstAppOutputRecord = true;
110 
111         // set the epoch number
112         this.writeCipher.authenticator.setEpochNumber(this.writeEpoch);
113     }
114 
115     @Override
encodeAlert(byte level, byte description)116     void encodeAlert(byte level, byte description) throws IOException {
117         if (isClosed()) {
118             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
119                 SSLLogger.warning("outbound has closed, ignore outbound " +
120                     "alert message: " + Alert.nameOf(description));
121             }
122             return;
123         }
124 
125         if (fragmenter == null) {
126            fragmenter = new DTLSFragmenter();
127         }
128 
129         fragmenter.queueUpAlert(level, description);
130     }
131 
132     @Override
encodeChangeCipherSpec()133     void encodeChangeCipherSpec() throws IOException {
134         if (isClosed()) {
135             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
136                 SSLLogger.warning("outbound has closed, ignore outbound " +
137                     "change_cipher_spec message");
138             }
139             return;
140         }
141 
142         if (fragmenter == null) {
143            fragmenter = new DTLSFragmenter();
144         }
145         fragmenter.queueUpChangeCipherSpec();
146     }
147 
148     @Override
encodeHandshake(byte[] source, int offset, int length)149     void encodeHandshake(byte[] source,
150             int offset, int length) throws IOException {
151         if (isClosed()) {
152             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
153                 SSLLogger.warning("outbound has closed, ignore outbound " +
154                         "handshake message",
155                         ByteBuffer.wrap(source, offset, length));
156             }
157             return;
158         }
159 
160         if (firstMessage) {
161             firstMessage = false;
162         }
163 
164         if (fragmenter == null) {
165            fragmenter = new DTLSFragmenter();
166         }
167 
168         fragmenter.queueUpHandshake(source, offset, length);
169     }
170 
171     @Override
encode( ByteBuffer[] srcs, int srcsOffset, int srcsLength, ByteBuffer[] dsts, int dstsOffset, int dstsLength)172     Ciphertext encode(
173         ByteBuffer[] srcs, int srcsOffset, int srcsLength,
174         ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
175 
176         if (isClosed) {
177             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
178                 SSLLogger.warning("outbound has closed, ignore outbound " +
179                     "application data or cached messages");
180             }
181 
182             return null;
183         } else if (isCloseWaiting) {
184             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
185                 SSLLogger.warning("outbound has closed, ignore outbound " +
186                     "application data");
187             }
188 
189             srcs = null;    // use no application data.
190         }
191 
192         return encode(srcs, srcsOffset, srcsLength, dsts[0]);
193     }
194 
encode(ByteBuffer[] sources, int offset, int length, ByteBuffer destination)195     private Ciphertext encode(ByteBuffer[] sources, int offset, int length,
196             ByteBuffer destination) throws IOException {
197 
198         if (writeCipher.authenticator.seqNumOverflow()) {
199             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
200                 SSLLogger.fine(
201                     "sequence number extremely close to overflow " +
202                     "(2^64-1 packets). Closing connection.");
203             }
204 
205             throw new SSLHandshakeException("sequence number overflow");
206         }
207 
208         // Don't process the incoming record until all of the buffered records
209         // get handled.  May need retransmission if no sources specified.
210         if (!isEmpty() || sources == null || sources.length == 0) {
211             Ciphertext ct = acquireCiphertext(destination);
212             if (ct != null) {
213                 return ct;
214             }
215         }
216 
217         if (sources == null || sources.length == 0) {
218             return null;
219         }
220 
221         int srcsRemains = 0;
222         for (int i = offset; i < offset + length; i++) {
223             srcsRemains += sources[i].remaining();
224         }
225 
226         if (srcsRemains == 0) {
227             return null;
228         }
229 
230         // not apply to handshake message
231         int fragLen;
232         if (packetSize > 0) {
233             fragLen = Math.min(maxRecordSize, packetSize);
234             fragLen = writeCipher.calculateFragmentSize(
235                     fragLen, headerSize);
236 
237             fragLen = Math.min(fragLen, Record.maxDataSize);
238         } else {
239             fragLen = Record.maxDataSize;
240         }
241 
242         if (fragmentSize > 0) {
243             fragLen = Math.min(fragLen, fragmentSize);
244         }
245 
246         int dstPos = destination.position();
247         int dstLim = destination.limit();
248         int dstContent = dstPos + headerSize +
249                                 writeCipher.getExplicitNonceSize();
250         destination.position(dstContent);
251 
252         int remains = Math.min(fragLen, destination.remaining());
253         fragLen = 0;
254         int srcsLen = offset + length;
255         for (int i = offset; (i < srcsLen) && (remains > 0); i++) {
256             int amount = Math.min(sources[i].remaining(), remains);
257             int srcLimit = sources[i].limit();
258             sources[i].limit(sources[i].position() + amount);
259             destination.put(sources[i]);
260             sources[i].limit(srcLimit);         // restore the limit
261             remains -= amount;
262             fragLen += amount;
263         }
264 
265         destination.limit(destination.position());
266         destination.position(dstContent);
267 
268         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
269             SSLLogger.fine(
270                     "WRITE: " + protocolVersion + " " +
271                     ContentType.APPLICATION_DATA.name +
272                     ", length = " + destination.remaining());
273         }
274 
275         // Encrypt the fragment and wrap up a record.
276         long recordSN = encrypt(writeCipher,
277                 ContentType.APPLICATION_DATA.id, destination,
278                 dstPos, dstLim, headerSize,
279                 protocolVersion);
280 
281         if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
282             ByteBuffer temporary = destination.duplicate();
283             temporary.limit(temporary.position());
284             temporary.position(dstPos);
285             SSLLogger.fine("Raw write", temporary);
286         }
287 
288         // remain the limit unchanged
289         destination.limit(dstLim);
290 
291         return new Ciphertext(ContentType.APPLICATION_DATA.id,
292                 SSLHandshake.NOT_APPLICABLE.id, recordSN);
293     }
294 
acquireCiphertext( ByteBuffer destination)295     private Ciphertext acquireCiphertext(
296             ByteBuffer destination) throws IOException {
297         if (fragmenter != null) {
298             return fragmenter.acquireCiphertext(destination);
299         }
300 
301         return null;
302     }
303 
304     @Override
isEmpty()305     boolean isEmpty() {
306         return (fragmenter == null) || fragmenter.isEmpty();
307     }
308 
309     @Override
launchRetransmission()310     void launchRetransmission() {
311         // Note: Please don't retransmit if there are handshake messages
312         // or alerts waiting in the queue.
313         if ((fragmenter != null) && fragmenter.isRetransmittable()) {
314             fragmenter.setRetransmission();
315         }
316     }
317 
318     // buffered record fragment
319     private static class RecordMemo {
320         byte            contentType;
321         byte            majorVersion;
322         byte            minorVersion;
323         int             encodeEpoch;
324         SSLWriteCipher  encodeCipher;
325 
326         byte[]          fragment;
327     }
328 
329     private static class HandshakeMemo extends RecordMemo {
330         byte            handshakeType;
331         int             messageSequence;
332         int             acquireOffset;
333     }
334 
335     private final class DTLSFragmenter {
336         private final LinkedList<RecordMemo> handshakeMemos =
337                 new LinkedList<>();
338         private int acquireIndex = 0;
339         private int messageSequence = 0;
340         private boolean flightIsReady = false;
341 
342         // Per section 4.1.1, RFC 6347:
343         //
344         // If repeated retransmissions do not result in a response, and the
345         // PMTU is unknown, subsequent retransmissions SHOULD back off to a
346         // smaller record size, fragmenting the handshake message as
347         // appropriate.
348         //
349         // In this implementation, two times of retransmits would be attempted
350         // before backing off.  The back off is supported only if the packet
351         // size is bigger than 256 bytes.
352         private int retransmits = 2;            // attemps of retransmits
353 
queueUpHandshake(byte[] buf, int offset, int length)354         void queueUpHandshake(byte[] buf,
355                 int offset, int length) throws IOException {
356 
357             // Cleanup if a new flight starts.
358             if (flightIsReady) {
359                 handshakeMemos.clear();
360                 acquireIndex = 0;
361                 flightIsReady = false;
362             }
363 
364             HandshakeMemo memo = new HandshakeMemo();
365 
366             memo.contentType = ContentType.HANDSHAKE.id;
367             memo.majorVersion = protocolVersion.major;
368             memo.minorVersion = protocolVersion.minor;
369             memo.encodeEpoch = writeEpoch;
370             memo.encodeCipher = writeCipher;
371 
372             memo.handshakeType = buf[offset];
373             memo.messageSequence = messageSequence++;
374             memo.acquireOffset = 0;
375             memo.fragment = new byte[length - 4];       // 4: header size
376                                                         //    1: HandshakeType
377                                                         //    3: message length
378             System.arraycopy(buf, offset + 4, memo.fragment, 0, length - 4);
379 
380             handshakeHashing(memo, memo.fragment);
381             handshakeMemos.add(memo);
382 
383             if ((memo.handshakeType == SSLHandshake.CLIENT_HELLO.id) ||
384                 (memo.handshakeType == SSLHandshake.HELLO_REQUEST.id) ||
385                 (memo.handshakeType ==
386                         SSLHandshake.HELLO_VERIFY_REQUEST.id) ||
387                 (memo.handshakeType == SSLHandshake.SERVER_HELLO_DONE.id) ||
388                 (memo.handshakeType == SSLHandshake.FINISHED.id)) {
389 
390                 flightIsReady = true;
391             }
392         }
393 
queueUpChangeCipherSpec()394         void queueUpChangeCipherSpec() {
395 
396             // Cleanup if a new flight starts.
397             if (flightIsReady) {
398                 handshakeMemos.clear();
399                 acquireIndex = 0;
400                 flightIsReady = false;
401             }
402 
403             RecordMemo memo = new RecordMemo();
404 
405             memo.contentType = ContentType.CHANGE_CIPHER_SPEC.id;
406             memo.majorVersion = protocolVersion.major;
407             memo.minorVersion = protocolVersion.minor;
408             memo.encodeEpoch = writeEpoch;
409             memo.encodeCipher = writeCipher;
410 
411             memo.fragment = new byte[1];
412             memo.fragment[0] = 1;
413 
414             handshakeMemos.add(memo);
415         }
416 
queueUpAlert(byte level, byte description)417         void queueUpAlert(byte level, byte description) throws IOException {
418             RecordMemo memo = new RecordMemo();
419 
420             memo.contentType = ContentType.ALERT.id;
421             memo.majorVersion = protocolVersion.major;
422             memo.minorVersion = protocolVersion.minor;
423             memo.encodeEpoch = writeEpoch;
424             memo.encodeCipher = writeCipher;
425 
426             memo.fragment = new byte[2];
427             memo.fragment[0] = level;
428             memo.fragment[1] = description;
429 
430             handshakeMemos.add(memo);
431         }
432 
acquireCiphertext(ByteBuffer dstBuf)433         Ciphertext acquireCiphertext(ByteBuffer dstBuf) throws IOException {
434             if (isEmpty()) {
435                 if (isRetransmittable()) {
436                     setRetransmission();    // configure for retransmission
437                 } else {
438                     return null;
439                 }
440             }
441 
442             RecordMemo memo = handshakeMemos.get(acquireIndex);
443             HandshakeMemo hsMemo = null;
444             if (memo.contentType == ContentType.HANDSHAKE.id) {
445                 hsMemo = (HandshakeMemo)memo;
446             }
447 
448             // ChangeCipherSpec message is pretty small.  Don't worry about
449             // the fragmentation of ChangeCipherSpec record.
450             int fragLen;
451             if (packetSize > 0) {
452                 fragLen = Math.min(maxRecordSize, packetSize);
453                 fragLen = memo.encodeCipher.calculateFragmentSize(
454                         fragLen, 25);   // 25: header size
455                                                 //   13: DTLS record
456                                                 //   12: DTLS handshake message
457                 fragLen = Math.min(fragLen, Record.maxDataSize);
458             } else {
459                 fragLen = Record.maxDataSize;
460             }
461 
462             if (fragmentSize > 0) {
463                 fragLen = Math.min(fragLen, fragmentSize);
464             }
465 
466             int dstPos = dstBuf.position();
467             int dstLim = dstBuf.limit();
468             int dstContent = dstPos + headerSize +
469                                     memo.encodeCipher.getExplicitNonceSize();
470             dstBuf.position(dstContent);
471 
472             if (hsMemo != null) {
473                 fragLen = Math.min(fragLen,
474                         (hsMemo.fragment.length - hsMemo.acquireOffset));
475 
476                 dstBuf.put(hsMemo.handshakeType);
477                 dstBuf.put((byte)((hsMemo.fragment.length >> 16) & 0xFF));
478                 dstBuf.put((byte)((hsMemo.fragment.length >> 8) & 0xFF));
479                 dstBuf.put((byte)(hsMemo.fragment.length & 0xFF));
480                 dstBuf.put((byte)((hsMemo.messageSequence >> 8) & 0xFF));
481                 dstBuf.put((byte)(hsMemo.messageSequence & 0xFF));
482                 dstBuf.put((byte)((hsMemo.acquireOffset >> 16) & 0xFF));
483                 dstBuf.put((byte)((hsMemo.acquireOffset >> 8) & 0xFF));
484                 dstBuf.put((byte)(hsMemo.acquireOffset & 0xFF));
485                 dstBuf.put((byte)((fragLen >> 16) & 0xFF));
486                 dstBuf.put((byte)((fragLen >> 8) & 0xFF));
487                 dstBuf.put((byte)(fragLen & 0xFF));
488                 dstBuf.put(hsMemo.fragment, hsMemo.acquireOffset, fragLen);
489             } else {
490                 fragLen = Math.min(fragLen, memo.fragment.length);
491                 dstBuf.put(memo.fragment, 0, fragLen);
492             }
493 
494             dstBuf.limit(dstBuf.position());
495             dstBuf.position(dstContent);
496 
497             if (SSLLogger.isOn && SSLLogger.isOn("record")) {
498                 SSLLogger.fine(
499                         "WRITE: " + protocolVersion + " " +
500                         ContentType.nameOf(memo.contentType) +
501                         ", length = " + dstBuf.remaining());
502             }
503 
504             // Encrypt the fragment and wrap up a record.
505             long recordSN = encrypt(memo.encodeCipher,
506                     memo.contentType, dstBuf,
507                     dstPos, dstLim, headerSize,
508                     ProtocolVersion.valueOf(memo.majorVersion,
509                             memo.minorVersion));
510 
511             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
512                 ByteBuffer temporary = dstBuf.duplicate();
513                 temporary.limit(temporary.position());
514                 temporary.position(dstPos);
515                 SSLLogger.fine(
516                         "Raw write (" + temporary.remaining() + ")", temporary);
517             }
518 
519             // remain the limit unchanged
520             dstBuf.limit(dstLim);
521 
522             // Reset the fragmentation offset.
523             if (hsMemo != null) {
524                 hsMemo.acquireOffset += fragLen;
525                 if (hsMemo.acquireOffset == hsMemo.fragment.length) {
526                     acquireIndex++;
527                 }
528 
529                 return new Ciphertext(hsMemo.contentType,
530                         hsMemo.handshakeType, recordSN);
531             } else {
532                 if (isCloseWaiting &&
533                         memo.contentType == ContentType.ALERT.id) {
534                     close();
535                 }
536 
537                 acquireIndex++;
538                 return new Ciphertext(memo.contentType,
539                         SSLHandshake.NOT_APPLICABLE.id, recordSN);
540             }
541         }
542 
handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody)543         private void handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody) {
544 
545             byte hsType = hsFrag.handshakeType;
546             if (!handshakeHash.isHashable(hsType)) {
547                 // omitted from handshake hash computation
548                 return;
549             }
550 
551             // calculate the DTLS header
552             byte[] temporary = new byte[12];    // 12: handshake header size
553 
554             // Handshake.msg_type
555             temporary[0] = hsFrag.handshakeType;
556 
557             // Handshake.length
558             temporary[1] = (byte)((hsBody.length >> 16) & 0xFF);
559             temporary[2] = (byte)((hsBody.length >> 8) & 0xFF);
560             temporary[3] = (byte)(hsBody.length & 0xFF);
561 
562             // Handshake.message_seq
563             temporary[4] = (byte)((hsFrag.messageSequence >> 8) & 0xFF);
564             temporary[5] = (byte)(hsFrag.messageSequence & 0xFF);
565 
566             // Handshake.fragment_offset
567             temporary[6] = 0;
568             temporary[7] = 0;
569             temporary[8] = 0;
570 
571             // Handshake.fragment_length
572             temporary[9] = temporary[1];
573             temporary[10] = temporary[2];
574             temporary[11] = temporary[3];
575 
576             handshakeHash.deliver(temporary, 0, 12);
577             handshakeHash.deliver(hsBody, 0, hsBody.length);
578         }
579 
isEmpty()580         boolean isEmpty() {
581             if (!flightIsReady || handshakeMemos.isEmpty() ||
582                     acquireIndex >= handshakeMemos.size()) {
583                 return true;
584             }
585 
586             return false;
587         }
588 
hasAlert()589         boolean hasAlert() {
590             for (RecordMemo memo : handshakeMemos) {
591                 if (memo.contentType == ContentType.ALERT.id) {
592                     return true;
593                 }
594             }
595 
596             return false;
597         }
598 
isRetransmittable()599         boolean isRetransmittable() {
600             return (flightIsReady && !handshakeMemos.isEmpty() &&
601                                 (acquireIndex >= handshakeMemos.size()));
602         }
603 
setRetransmission()604         private void setRetransmission() {
605             acquireIndex = 0;
606             for (RecordMemo memo : handshakeMemos) {
607                 if (memo instanceof HandshakeMemo) {
608                     HandshakeMemo hmemo = (HandshakeMemo)memo;
609                     hmemo.acquireOffset = 0;
610                 }
611             }
612 
613             // Shrink packet size if:
614             // 1. maximum fragment size is allowed, in which case the packet
615             //    size is configured bigger than maxRecordSize;
616             // 2. maximum packet is bigger than 256 bytes;
617             // 3. two times of retransmits have been attempted.
618             if ((packetSize <= maxRecordSize) &&
619                     (packetSize > 256) && ((retransmits--) <= 0)) {
620 
621                 // shrink packet size
622                 shrinkPacketSize();
623                 retransmits = 2;        // attemps of retransmits
624             }
625         }
626 
shrinkPacketSize()627         private void shrinkPacketSize() {
628             packetSize = Math.max(256, packetSize / 2);
629         }
630     }
631 }
632