1 /*
2  * Copyright (c) 2018, 2020 Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package sun.security.ssl;
27 
28 import java.io.IOException;
29 import java.nio.ByteBuffer;
30 import java.text.MessageFormat;
31 import java.util.*;
32 
33 import sun.security.ssl.SSLHandshake.HandshakeMessage;
34 import sun.security.util.HexDumpEncoder;
35 
36 /**
37  * SSL/(D)TLS extensions in a handshake message.
38  */
39 final class SSLExtensions {
40     private final HandshakeMessage handshakeMessage;
41     private Map<SSLExtension, byte[]> extMap = new LinkedHashMap<>();
42     private int encodedLength;
43 
44     // Extension map for debug logging
45     private final Map<Integer, byte[]> logMap =
46             SSLLogger.isOn ? new LinkedHashMap<>() : null;
47 
SSLExtensions(HandshakeMessage handshakeMessage)48     SSLExtensions(HandshakeMessage handshakeMessage) {
49         this.handshakeMessage = handshakeMessage;
50         this.encodedLength = 2;         // 2: the length of the extensions.
51     }
52 
SSLExtensions(HandshakeMessage hm, ByteBuffer m, SSLExtension[] extensions)53     SSLExtensions(HandshakeMessage hm,
54             ByteBuffer m, SSLExtension[] extensions) throws IOException {
55         this.handshakeMessage = hm;
56 
57         int len = Record.getInt16(m);
58         encodedLength = len + 2;        // 2: the length of the extensions.
59         while (len > 0) {
60             int extId = Record.getInt16(m);
61             int extLen = Record.getInt16(m);
62             if (extLen > m.remaining()) {
63                 throw hm.handshakeContext.conContext.fatal(
64                         Alert.ILLEGAL_PARAMETER,
65                         "Error parsing extension (" + extId +
66                         "): no sufficient data");
67             }
68 
69             boolean isSupported = true;
70             SSLHandshake handshakeType = hm.handshakeType();
71             if (SSLExtension.isConsumable(extId) &&
72                     SSLExtension.valueOf(handshakeType, extId) == null) {
73                 if (extId == SSLExtension.CH_SUPPORTED_GROUPS.id &&
74                         handshakeType == SSLHandshake.SERVER_HELLO) {
75                     // Note: It does not comply to the specification.  However,
76                     // there are servers that send the supported_groups
77                     // extension in ServerHello handshake message.
78                     //
79                     // TLS 1.3 should not send this extension.   We may want to
80                     // limit the workaround for TLS 1.2 and prior version only.
81                     // However, the implementation of the limit is complicated
82                     // and inefficient, and may not worthy the maintenance.
83                     isSupported = false;
84                     if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
85                         SSLLogger.warning(
86                                 "Received buggy supported_groups extension " +
87                                 "in the ServerHello handshake message");
88                     }
89                 } else if (handshakeType == SSLHandshake.SERVER_HELLO) {
90                     throw hm.handshakeContext.conContext.fatal(
91                             Alert.UNSUPPORTED_EXTENSION, "extension (" +
92                                     extId + ") should not be presented in " +
93                                     handshakeType.name);
94                 } else {
95                     isSupported = false;
96                     // debug log to ignore unknown extension for handshakeType
97                 }
98             }
99 
100             if (isSupported) {
101                 isSupported = false;
102                 for (SSLExtension extension : extensions) {
103                     if ((extension.id != extId) ||
104                             (extension.onLoadConsumer == null)) {
105                         continue;
106                     }
107 
108                     if (extension.handshakeType != handshakeType) {
109                         throw hm.handshakeContext.conContext.fatal(
110                                 Alert.UNSUPPORTED_EXTENSION,
111                                 "extension (" + extId + ") should not be " +
112                                 "presented in " + handshakeType.name);
113                     }
114 
115                     byte[] extData = new byte[extLen];
116                     m.get(extData);
117                     extMap.put(extension, extData);
118                     if (logMap != null) {
119                         logMap.put(extId, extData);
120                     }
121 
122                     isSupported = true;
123                     break;
124                 }
125             }
126 
127             if (!isSupported) {
128                 if (logMap != null) {
129                     // cache the extension for debug logging
130                     byte[] extData = new byte[extLen];
131                     m.get(extData);
132                     logMap.put(extId, extData);
133 
134                     if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
135                         SSLLogger.fine(
136                                 "Ignore unknown or unsupported extension",
137                                 toString(extId, extData));
138                     }
139                 } else {
140                     // ignore the extension
141                     int pos = m.position() + extLen;
142                     m.position(pos);
143                 }
144             }
145 
146             len -= extLen + 4;
147         }
148     }
149 
get(SSLExtension ext)150     byte[] get(SSLExtension ext) {
151         return extMap.get(ext);
152     }
153 
154     /**
155      * Consume the specified extensions.
156      */
consumeOnLoad(HandshakeContext context, SSLExtension[] extensions)157     void consumeOnLoad(HandshakeContext context,
158             SSLExtension[] extensions) throws IOException {
159         for (SSLExtension extension : extensions) {
160             if (context.negotiatedProtocol != null &&
161                     !extension.isAvailable(context.negotiatedProtocol)) {
162                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
163                     SSLLogger.fine(
164                         "Ignore unsupported extension: " + extension.name);
165                 }
166                 continue;
167             }
168 
169             if (!extMap.containsKey(extension)) {
170                 if (extension.onLoadAbsence != null) {
171                     extension.absentOnLoad(context, handshakeMessage);
172                 } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
173                     SSLLogger.fine(
174                         "Ignore unavailable extension: " + extension.name);
175                 }
176                 continue;
177             }
178 
179 
180             if (extension.onLoadConsumer == null) {
181                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
182                     SSLLogger.warning(
183                         "Ignore unsupported extension: " + extension.name);
184                 }
185                 continue;
186             }
187 
188             ByteBuffer m = ByteBuffer.wrap(extMap.get(extension));
189             extension.consumeOnLoad(context, handshakeMessage, m);
190 
191             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
192                 SSLLogger.fine("Consumed extension: " + extension.name);
193             }
194         }
195     }
196 
197     /**
198      * Consider impact of the specified extensions.
199      */
consumeOnTrade(HandshakeContext context, SSLExtension[] extensions)200     void consumeOnTrade(HandshakeContext context,
201             SSLExtension[] extensions) throws IOException {
202         for (SSLExtension extension : extensions) {
203             if (!extMap.containsKey(extension)) {
204                 if (extension.onTradeAbsence != null) {
205                     extension.absentOnTrade(context, handshakeMessage);
206                 } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
207                     SSLLogger.fine(
208                         "Ignore unavailable extension: " + extension.name);
209                 }
210                 continue;
211             }
212 
213             if (extension.onTradeConsumer == null) {
214                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
215                     SSLLogger.warning(
216                             "Ignore impact of unsupported extension: " +
217                             extension.name);
218                 }
219                 continue;
220             }
221 
222             extension.consumeOnTrade(context, handshakeMessage);
223             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
224                 SSLLogger.fine("Populated with extension: " + extension.name);
225             }
226         }
227     }
228 
229     /**
230      * Produce extension values for the specified extensions.
231      */
produce(HandshakeContext context, SSLExtension[] extensions)232     void produce(HandshakeContext context,
233             SSLExtension[] extensions) throws IOException {
234         for (SSLExtension extension : extensions) {
235             if (extMap.containsKey(extension)) {
236                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
237                     SSLLogger.fine(
238                             "Ignore, duplicated extension: " +
239                             extension.name);
240                 }
241                 continue;
242             }
243 
244             if (extension.networkProducer == null) {
245                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
246                     SSLLogger.warning(
247                             "Ignore, no extension producer defined: " +
248                             extension.name);
249                 }
250                 continue;
251             }
252 
253             byte[] encoded = extension.produce(context, handshakeMessage);
254             if (encoded != null) {
255                 extMap.put(extension, encoded);
256                 encodedLength += encoded.length + 4; // extension_type (2)
257                                                      // extension_data length(2)
258             } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
259                 // The extension is not available in the context.
260                 SSLLogger.fine(
261                         "Ignore, context unavailable extension: " +
262                         extension.name);
263             }
264         }
265     }
266 
267     /**
268      * Produce extension values for the specified extensions, replacing if
269      * there is an existing extension value for a specified extension.
270      */
reproduce(HandshakeContext context, SSLExtension[] extensions)271     void reproduce(HandshakeContext context,
272             SSLExtension[] extensions) throws IOException {
273         for (SSLExtension extension : extensions) {
274             if (extension.networkProducer == null) {
275                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
276                     SSLLogger.warning(
277                             "Ignore, no extension producer defined: " +
278                             extension.name);
279                 }
280                 continue;
281             }
282 
283             byte[] encoded = extension.produce(context, handshakeMessage);
284             if (encoded != null) {
285                 if (extMap.containsKey(extension)) {
286                     byte[] old = extMap.replace(extension, encoded);
287                     if (old != null) {
288                         encodedLength -= old.length + 4;
289                     }
290                     encodedLength += encoded.length + 4;
291                 } else {
292                     extMap.put(extension, encoded);
293                     encodedLength += encoded.length + 4;
294                                                     // extension_type (2)
295                                                     // extension_data length(2)
296                 }
297             } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
298                 // The extension is not available in the context.
299                 SSLLogger.fine(
300                         "Ignore, context unavailable extension: " +
301                         extension.name);
302             }
303         }
304     }
305 
306     // Note that TLS 1.3 may use empty extensions.  Please consider it while
307     // using this method.
length()308     int length() {
309         if (extMap.isEmpty()) {
310             return 0;
311         } else {
312             return encodedLength;
313         }
314     }
315 
316     // Note that TLS 1.3 may use empty extensions.  Please consider it while
317     // using this method.
send(HandshakeOutStream hos)318     void send(HandshakeOutStream hos) throws IOException {
319         int extsLen = length();
320         if (extsLen == 0) {
321             return;
322         }
323         hos.putInt16(extsLen - 2);
324         // extensions must be sent in the order they appear in the enum
325         for (SSLExtension ext : SSLExtension.values()) {
326             byte[] extData = extMap.get(ext);
327             if (extData != null) {
328                 hos.putInt16(ext.id);
329                 hos.putBytes16(extData);
330             }
331         }
332     }
333 
334     @Override
toString()335     public String toString() {
336         if (extMap.isEmpty() && (logMap == null || logMap.isEmpty())) {
337             return "<no extension>";
338         } else {
339             StringBuilder builder = new StringBuilder(512);
340             if (logMap != null && !logMap.isEmpty()) {
341                 for (Map.Entry<Integer, byte[]> en : logMap.entrySet()) {
342                     SSLExtension ext = SSLExtension.valueOf(
343                             handshakeMessage.handshakeType(), en.getKey());
344                     if (builder.length() != 0) {
345                         builder.append(",\n");
346                     }
347                     if (ext != null) {
348                         builder.append(
349                                 ext.toString(ByteBuffer.wrap(en.getValue())));
350                     } else {
351                         builder.append(toString(en.getKey(), en.getValue()));
352                     }
353                 }
354 
355                 return builder.toString();
356             } else {
357                 for (Map.Entry<SSLExtension, byte[]> en : extMap.entrySet()) {
358                     if (builder.length() != 0) {
359                         builder.append(",\n");
360                     }
361                     builder.append(
362                         en.getKey().toString(ByteBuffer.wrap(en.getValue())));
363                 }
364 
365                 return builder.toString();
366             }
367         }
368     }
369 
toString(int extId, byte[] extData)370     private static String toString(int extId, byte[] extData) {
371         String extName = SSLExtension.nameOf(extId);
372         MessageFormat messageFormat = new MessageFormat(
373             "\"{0} ({1})\": '{'\n" +
374             "{2}\n" +
375             "'}'",
376             Locale.ENGLISH);
377 
378         HexDumpEncoder hexEncoder = new HexDumpEncoder();
379         String encoded = hexEncoder.encodeBuffer(extData);
380 
381         Object[] messageFields = {
382             extName,
383             extId,
384             Utilities.indent(encoded)
385         };
386 
387         return messageFormat.format(messageFields);
388     }
389 }
390