1 /*
2  * Copyright (c) 1996, 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package sun.security.ssl;
27 
28 import java.io.IOException;
29 import java.nio.ByteBuffer;
30 import java.security.GeneralSecurityException;
31 import java.util.ArrayList;
32 import javax.crypto.BadPaddingException;
33 import javax.net.ssl.SSLException;
34 import javax.net.ssl.SSLHandshakeException;
35 import javax.net.ssl.SSLProtocolException;
36 import sun.security.ssl.SSLCipher.SSLReadCipher;
37 
38 /**
39  * {@code InputRecord} implementation for {@code SSLEngine}.
40  */
41 final class SSLEngineInputRecord extends InputRecord implements SSLRecord {
42     private boolean formatVerified = false;     // SSLv2 ruled out?
43 
44     // Cache for incomplete handshake messages.
45     private ByteBuffer handshakeBuffer = null;
46 
SSLEngineInputRecord(HandshakeHash handshakeHash)47     SSLEngineInputRecord(HandshakeHash handshakeHash) {
48         super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
49     }
50 
51     @Override
estimateFragmentSize(int packetSize)52     int estimateFragmentSize(int packetSize) {
53         if (packetSize > 0) {
54             return readCipher.estimateFragmentSize(packetSize, headerSize);
55         } else {
56             return Record.maxDataSize;
57         }
58     }
59 
60     @Override
bytesInCompletePacket( ByteBuffer[] srcs, int srcsOffset, int srcsLength)61     int bytesInCompletePacket(
62         ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException {
63 
64         return bytesInCompletePacket(srcs[srcsOffset]);
65     }
66 
bytesInCompletePacket(ByteBuffer packet)67     private int bytesInCompletePacket(ByteBuffer packet) throws SSLException {
68         /*
69          * SSLv2 length field is in bytes 0/1
70          * SSLv3/TLS length field is in bytes 3/4
71          */
72         if (packet.remaining() < 5) {
73             return -1;
74         }
75 
76         int pos = packet.position();
77         byte byteZero = packet.get(pos);
78 
79         int len = 0;
80 
81         /*
82          * If we have already verified previous packets, we can
83          * ignore the verifications steps, and jump right to the
84          * determination.  Otherwise, try one last heuristic to
85          * see if it's SSL/TLS.
86          */
87         if (formatVerified ||
88                 (byteZero == ContentType.HANDSHAKE.id) ||
89                 (byteZero == ContentType.ALERT.id)) {
90             /*
91              * Last sanity check that it's not a wild record
92              */
93             byte majorVersion = packet.get(pos + 1);
94             byte minorVersion = packet.get(pos + 2);
95             if (!ProtocolVersion.isNegotiable(
96                     majorVersion, minorVersion, false, false)) {
97                 throw new SSLException("Unrecognized record version " +
98                         ProtocolVersion.nameOf(majorVersion, minorVersion) +
99                         " , plaintext connection?");
100             }
101 
102             /*
103              * Reasonably sure this is a V3, disable further checks.
104              * We can't do the same in the v2 check below, because
105              * read still needs to parse/handle the v2 clientHello.
106              */
107             formatVerified = true;
108 
109             /*
110              * One of the SSLv3/TLS message types.
111              */
112             len = ((packet.get(pos + 3) & 0xFF) << 8) +
113                    (packet.get(pos + 4) & 0xFF) + headerSize;
114 
115         } else {
116             /*
117              * Must be SSLv2 or something unknown.
118              * Check if it's short (2 bytes) or
119              * long (3) header.
120              *
121              * Internals can warn about unsupported SSLv2
122              */
123             boolean isShort = ((byteZero & 0x80) != 0);
124 
125             if (isShort &&
126                     ((packet.get(pos + 2) == 1) || packet.get(pos + 2) == 4)) {
127 
128                 byte majorVersion = packet.get(pos + 3);
129                 byte minorVersion = packet.get(pos + 4);
130                 if (!ProtocolVersion.isNegotiable(
131                         majorVersion, minorVersion, false, false)) {
132                     throw new SSLException("Unrecognized record version " +
133                             ProtocolVersion.nameOf(majorVersion, minorVersion) +
134                             " , plaintext connection?");
135                 }
136 
137                 /*
138                  * Client or Server Hello
139                  */
140                 int mask = (isShort ? 0x7F : 0x3F);
141                 len = ((byteZero & mask) << 8) +
142                         (packet.get(pos + 1) & 0xFF) + (isShort ? 2 : 3);
143 
144             } else {
145                 // Gobblygook!
146                 throw new SSLException(
147                         "Unrecognized SSL message, plaintext connection?");
148             }
149         }
150 
151         return len;
152     }
153 
154     @Override
decode(ByteBuffer[] srcs, int srcsOffset, int srcsLength)155     Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
156             int srcsLength) throws IOException, BadPaddingException {
157         if (srcs == null || srcs.length == 0 || srcsLength == 0) {
158             return new Plaintext[0];
159         } else if (srcsLength == 1) {
160             return decode(srcs[srcsOffset]);
161         } else {
162             ByteBuffer packet = extract(srcs,
163                     srcsOffset, srcsLength, SSLRecord.headerSize);
164 
165             return decode(packet);
166         }
167     }
168 
decode(ByteBuffer packet)169     private Plaintext[] decode(ByteBuffer packet)
170             throws IOException, BadPaddingException {
171 
172         if (isClosed) {
173             return null;
174         }
175 
176         if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
177             SSLLogger.fine("Raw read", packet);
178         }
179 
180         // The caller should have validated the record.
181         if (!formatVerified) {
182             formatVerified = true;
183 
184             /*
185              * The first record must either be a handshake record or an
186              * alert message. If it's not, it is either invalid or an
187              * SSLv2 message.
188              */
189             int pos = packet.position();
190             byte byteZero = packet.get(pos);
191             if (byteZero != ContentType.HANDSHAKE.id &&
192                     byteZero != ContentType.ALERT.id) {
193                 return handleUnknownRecord(packet);
194             }
195         }
196 
197         return decodeInputRecord(packet);
198     }
199 
decodeInputRecord(ByteBuffer packet)200     private Plaintext[] decodeInputRecord(ByteBuffer packet)
201             throws IOException, BadPaddingException {
202         //
203         // The packet should be a complete record, or more.
204         //
205         int srcPos = packet.position();
206         int srcLim = packet.limit();
207 
208         byte contentType = packet.get();                   // pos: 0
209         byte majorVersion = packet.get();                  // pos: 1
210         byte minorVersion = packet.get();                  // pos: 2
211         int contentLen = Record.getInt16(packet);          // pos: 3, 4
212 
213         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
214             SSLLogger.fine(
215                     "READ: " +
216                     ProtocolVersion.nameOf(majorVersion, minorVersion) +
217                     " " + ContentType.nameOf(contentType) + ", length = " +
218                     contentLen);
219         }
220 
221         //
222         // Check for upper bound.
223         //
224         // Note: May check packetSize limit in the future.
225         if (contentLen < 0 || contentLen > maxLargeRecordSize - headerSize) {
226             throw new SSLProtocolException(
227                 "Bad input record size, TLSCiphertext.length = " + contentLen);
228         }
229 
230         //
231         // Decrypt the fragment
232         //
233         int recLim = srcPos + SSLRecord.headerSize + contentLen;
234         packet.limit(recLim);
235         packet.position(srcPos + SSLRecord.headerSize);
236 
237         ByteBuffer fragment;
238         try {
239             Plaintext plaintext =
240                     readCipher.decrypt(contentType, packet, null);
241             fragment = plaintext.fragment;
242             contentType = plaintext.contentType;
243         } catch (BadPaddingException bpe) {
244             throw bpe;
245         } catch (GeneralSecurityException gse) {
246             throw (SSLProtocolException)(new SSLProtocolException(
247                     "Unexpected exception")).initCause(gse);
248         } finally {
249             // consume a complete record
250             packet.limit(srcLim);
251             packet.position(recLim);
252         }
253 
254         //
255         // check for handshake fragment
256         //
257         if (contentType != ContentType.HANDSHAKE.id &&
258                 handshakeBuffer != null && handshakeBuffer.hasRemaining()) {
259             throw new SSLProtocolException(
260                     "Expecting a handshake fragment, but received " +
261                     ContentType.nameOf(contentType));
262         }
263 
264         //
265         // parse handshake messages
266         //
267         if (contentType == ContentType.HANDSHAKE.id) {
268             ByteBuffer handshakeFrag = fragment;
269             if ((handshakeBuffer != null) &&
270                     (handshakeBuffer.remaining() != 0)) {
271                 ByteBuffer bb = ByteBuffer.wrap(new byte[
272                         handshakeBuffer.remaining() + fragment.remaining()]);
273                 bb.put(handshakeBuffer);
274                 bb.put(fragment);
275                 handshakeFrag = bb.rewind();
276                 handshakeBuffer = null;
277             }
278 
279             ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
280             while (handshakeFrag.hasRemaining()) {
281                 int remaining = handshakeFrag.remaining();
282                 if (remaining < handshakeHeaderSize) {
283                     handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
284                     handshakeBuffer.put(handshakeFrag);
285                     handshakeBuffer.rewind();
286                     break;
287                 }
288 
289                 handshakeFrag.mark();
290 
291                 // Fail fast for unknown handshake message.
292                 byte handshakeType = handshakeFrag.get();
293                 if (!SSLHandshake.isKnown(handshakeType)) {
294                     throw new SSLProtocolException(
295                         "Unknown handshake type size, Handshake.msg_type = " +
296                         (handshakeType & 0xFF));
297                 }
298 
299                 int handshakeBodyLen = Record.getInt24(handshakeFrag);
300                 if (handshakeBodyLen > SSLConfiguration.maxHandshakeMessageSize) {
301                     throw new SSLProtocolException(
302                             "The size of the handshake message ("
303                             + handshakeBodyLen
304                             + ") exceeds the maximum allowed size ("
305                             + SSLConfiguration.maxHandshakeMessageSize
306                             + ")");
307                 }
308 
309                 handshakeFrag.reset();
310                 int handshakeMessageLen =
311                         handshakeHeaderSize + handshakeBodyLen;
312                 if (remaining < handshakeMessageLen) {
313                     handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
314                     handshakeBuffer.put(handshakeFrag);
315                     handshakeBuffer.rewind();
316                     break;
317                 } else if (remaining == handshakeMessageLen) {
318                     if (handshakeHash.isHashable(handshakeType)) {
319                         handshakeHash.receive(handshakeFrag);
320                     }
321 
322                     plaintexts.add(
323                         new Plaintext(contentType,
324                             majorVersion, minorVersion, -1, -1L, handshakeFrag)
325                     );
326                     break;
327                 } else {
328                     int fragPos = handshakeFrag.position();
329                     int fragLim = handshakeFrag.limit();
330                     int nextPos = fragPos + handshakeMessageLen;
331                     handshakeFrag.limit(nextPos);
332 
333                     if (handshakeHash.isHashable(handshakeType)) {
334                         handshakeHash.receive(handshakeFrag);
335                     }
336 
337                     plaintexts.add(
338                         new Plaintext(contentType, majorVersion, minorVersion,
339                             -1, -1L, handshakeFrag.slice())
340                     );
341 
342                     handshakeFrag.position(nextPos);
343                     handshakeFrag.limit(fragLim);
344                 }
345             }
346 
347             return plaintexts.toArray(new Plaintext[0]);
348         }
349 
350         return new Plaintext[] {
351             new Plaintext(contentType,
352                 majorVersion, minorVersion, -1, -1L, fragment)
353         };
354     }
355 
handleUnknownRecord(ByteBuffer packet)356     private Plaintext[] handleUnknownRecord(ByteBuffer packet)
357             throws IOException, BadPaddingException {
358         //
359         // The packet should be a complete record.
360         //
361         int srcPos = packet.position();
362         int srcLim = packet.limit();
363 
364         byte firstByte = packet.get(srcPos);
365         byte thirdByte = packet.get(srcPos + 2);
366 
367         // Does it look like a Version 2 client hello (V2ClientHello)?
368         if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {
369             /*
370              * If SSLv2Hello is not enabled, throw an exception.
371              */
372             if (helloVersion != ProtocolVersion.SSL20Hello) {
373                 throw new SSLHandshakeException("SSLv2Hello is not enabled");
374             }
375 
376             byte majorVersion = packet.get(srcPos + 3);
377             byte minorVersion = packet.get(srcPos + 4);
378 
379             if ((majorVersion == ProtocolVersion.SSL20Hello.major) &&
380                 (minorVersion == ProtocolVersion.SSL20Hello.minor)) {
381 
382                 /*
383                  * Looks like a V2 client hello, but not one saying
384                  * "let's talk SSLv3".  So we need to send an SSLv2
385                  * error message, one that's treated as fatal by
386                  * clients (Otherwise we'll hang.)
387                  */
388                 if (SSLLogger.isOn && SSLLogger.isOn("record")) {
389                    SSLLogger.fine(
390                             "Requested to negotiate unsupported SSLv2!");
391                 }
392 
393                 // Note that the exception is caught in SSLEngineImpl
394                 // so that SSLv2 error message can be delivered properly.
395                 throw new UnsupportedOperationException(        // SSLv2Hello
396                         "Unsupported SSL v2.0 ClientHello");
397             }
398 
399             /*
400              * If we can map this into a V3 ClientHello, read and
401              * hash the rest of the V2 handshake, turn it into a
402              * V3 ClientHello message, and pass it up.
403              */
404             packet.position(srcPos + 2);        // exclude the header
405             handshakeHash.receive(packet);
406             packet.position(srcPos);
407 
408             ByteBuffer converted = convertToClientHello(packet);
409 
410             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
411                 SSLLogger.fine(
412                         "[Converted] ClientHello", converted);
413             }
414 
415             return new Plaintext[] {
416                     new Plaintext(ContentType.HANDSHAKE.id,
417                     majorVersion, minorVersion, -1, -1L, converted)
418                 };
419         } else {
420             if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
421                 throw new SSLException("SSL V2.0 servers are not supported.");
422             }
423 
424             throw new SSLException("Unsupported or unrecognized SSL message");
425         }
426     }
427 }
428