1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package nginx.unit.websocket.server;
18 
19 import java.io.IOException;
20 import java.nio.charset.StandardCharsets;
21 import java.util.ArrayList;
22 import java.util.Collections;
23 import java.util.Enumeration;
24 import java.util.LinkedHashMap;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Map.Entry;
28 
29 import javax.servlet.ServletException;
30 import javax.servlet.ServletRequest;
31 import javax.servlet.ServletResponse;
32 import javax.servlet.http.HttpServletRequest;
33 import javax.servlet.http.HttpServletResponse;
34 import javax.websocket.Endpoint;
35 import javax.websocket.Extension;
36 import javax.websocket.HandshakeResponse;
37 import javax.websocket.server.ServerEndpointConfig;
38 
39 import nginx.unit.Request;
40 
41 import org.apache.tomcat.util.codec.binary.Base64;
42 import org.apache.tomcat.util.res.StringManager;
43 import org.apache.tomcat.util.security.ConcurrentMessageDigest;
44 import nginx.unit.websocket.Constants;
45 import nginx.unit.websocket.Transformation;
46 import nginx.unit.websocket.TransformationFactory;
47 import nginx.unit.websocket.Util;
48 import nginx.unit.websocket.WsHandshakeResponse;
49 import nginx.unit.websocket.pojo.PojoEndpointServer;
50 
51 public class UpgradeUtil {
52 
53     private static final StringManager sm =
54             StringManager.getManager(UpgradeUtil.class.getPackage().getName());
55     private static final byte[] WS_ACCEPT =
56             "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
57                     StandardCharsets.ISO_8859_1);
58 
UpgradeUtil()59     private UpgradeUtil() {
60         // Utility class. Hide default constructor.
61     }
62 
63     /**
64      * Checks to see if this is an HTTP request that includes a valid upgrade
65      * request to web socket.
66      * <p>
67      * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
68      *       WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
69      *       6455 section 4.1 requires that a WebSocket Upgrade uses GET.
70      * @param request  The request to check if it is an HTTP upgrade request for
71      *                 a WebSocket connection
72      * @param response The response associated with the request
73      * @return <code>true</code> if the request includes a HTTP Upgrade request
74      *         for the WebSocket protocol, otherwise <code>false</code>
75      */
isWebSocketUpgradeRequest(ServletRequest request, ServletResponse response)76     public static boolean isWebSocketUpgradeRequest(ServletRequest request,
77             ServletResponse response) {
78 
79         Request r = (Request) request.getAttribute(Request.BARE);
80 
81         return ((request instanceof HttpServletRequest) &&
82                 (response instanceof HttpServletResponse) &&
83                 (r != null) &&
84                 (r.isUpgrade()));
85     }
86 
87 
doUpgrade(WsServerContainer sc, HttpServletRequest req, HttpServletResponse resp, ServerEndpointConfig sec, Map<String,String> pathParams)88     public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
89             HttpServletResponse resp, ServerEndpointConfig sec,
90             Map<String,String> pathParams)
91             throws ServletException, IOException {
92 
93 
94         // Origin check
95         String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
96 
97         if (!sec.getConfigurator().checkOrigin(origin)) {
98             resp.sendError(HttpServletResponse.SC_FORBIDDEN);
99             return;
100         }
101         // Sub-protocols
102         List<String> subProtocols = getTokensFromHeader(req,
103                 Constants.WS_PROTOCOL_HEADER_NAME);
104         String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
105                 sec.getSubprotocols(), subProtocols);
106 
107         // Extensions
108         // Should normally only be one header but handle the case of multiple
109         // headers
110         List<Extension> extensionsRequested = new ArrayList<>();
111         Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
112         while (extHeaders.hasMoreElements()) {
113             Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
114         }
115 
116         // Negotiation phase 1. By default this simply filters out the
117         // extensions that the server does not support but applications could
118         // use a custom configurator to do more than this.
119         List<Extension> installedExtensions = null;
120         if (sec.getExtensions().size() == 0) {
121             installedExtensions = Constants.INSTALLED_EXTENSIONS;
122         } else {
123             installedExtensions = new ArrayList<>();
124             installedExtensions.addAll(sec.getExtensions());
125             installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
126         }
127         List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
128                 installedExtensions, extensionsRequested);
129 
130         // Negotiation phase 2. Create the Transformations that will be applied
131         // to this connection. Note than an extension may be dropped at this
132         // point if the client has requested a configuration that the server is
133         // unable to support.
134         List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);
135 
136         List<Extension> negotiatedExtensionsPhase2;
137         if (transformations.isEmpty()) {
138             negotiatedExtensionsPhase2 = Collections.emptyList();
139         } else {
140             negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size());
141             for (Transformation t : transformations) {
142                 negotiatedExtensionsPhase2.add(t.getExtensionResponse());
143             }
144         }
145 
146         WsHttpUpgradeHandler wsHandler =
147                 req.upgrade(WsHttpUpgradeHandler.class);
148 
149         WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
150         WsHandshakeResponse wsResponse = new WsHandshakeResponse();
151         WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
152                 new WsPerSessionServerEndpointConfig(sec);
153         sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
154                 wsRequest, wsResponse);
155         //wsRequest.finished();
156 
157         // Add any additional headers
158         for (Entry<String,List<String>> entry :
159                 wsResponse.getHeaders().entrySet()) {
160             for (String headerValue: entry.getValue()) {
161                 resp.addHeader(entry.getKey(), headerValue);
162             }
163         }
164 
165         Endpoint ep;
166         try {
167             Class<?> clazz = sec.getEndpointClass();
168             if (Endpoint.class.isAssignableFrom(clazz)) {
169                 ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
170                         clazz);
171             } else {
172                 ep = new PojoEndpointServer();
173                 // Need to make path params available to POJO
174                 perSessionServerEndpointConfig.getUserProperties().put(
175                         nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams);
176             }
177         } catch (InstantiationException e) {
178             throw new ServletException(e);
179         }
180 
181         wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
182                 negotiatedExtensionsPhase2, subProtocol, null, pathParams,
183                 req.isSecure());
184 
185         wsHandler.init(null);
186     }
187 
188 
createTransformations( List<Extension> negotiatedExtensions)189     private static List<Transformation> createTransformations(
190             List<Extension> negotiatedExtensions) {
191 
192         TransformationFactory factory = TransformationFactory.getInstance();
193 
194         LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences =
195                 new LinkedHashMap<>();
196 
197         // Result will likely be smaller than this
198         List<Transformation> result = new ArrayList<>(negotiatedExtensions.size());
199 
200         for (Extension extension : negotiatedExtensions) {
201             List<List<Extension.Parameter>> preferences =
202                     extensionPreferences.get(extension.getName());
203 
204             if (preferences == null) {
205                 preferences = new ArrayList<>();
206                 extensionPreferences.put(extension.getName(), preferences);
207             }
208 
209             preferences.add(extension.getParameters());
210         }
211 
212         for (Map.Entry<String,List<List<Extension.Parameter>>> entry :
213             extensionPreferences.entrySet()) {
214             Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
215             if (transformation != null) {
216                 result.add(transformation);
217             }
218         }
219         return result;
220     }
221 
222 
append(StringBuilder sb, Extension extension)223     private static void append(StringBuilder sb, Extension extension) {
224         if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
225             return;
226         }
227 
228         sb.append(extension.getName());
229 
230         for (Extension.Parameter p : extension.getParameters()) {
231             sb.append(';');
232             sb.append(p.getName());
233             if (p.getValue() != null) {
234                 sb.append('=');
235                 sb.append(p.getValue());
236             }
237         }
238     }
239 
240 
241     /*
242      * This only works for tokens. Quoted strings need more sophisticated
243      * parsing.
244      */
headerContainsToken(HttpServletRequest req, String headerName, String target)245     private static boolean headerContainsToken(HttpServletRequest req,
246             String headerName, String target) {
247         Enumeration<String> headers = req.getHeaders(headerName);
248         while (headers.hasMoreElements()) {
249             String header = headers.nextElement();
250             String[] tokens = header.split(",");
251             for (String token : tokens) {
252                 if (target.equalsIgnoreCase(token.trim())) {
253                     return true;
254                 }
255             }
256         }
257         return false;
258     }
259 
260 
261     /*
262      * This only works for tokens. Quoted strings need more sophisticated
263      * parsing.
264      */
getTokensFromHeader(HttpServletRequest req, String headerName)265     private static List<String> getTokensFromHeader(HttpServletRequest req,
266             String headerName) {
267         List<String> result = new ArrayList<>();
268         Enumeration<String> headers = req.getHeaders(headerName);
269         while (headers.hasMoreElements()) {
270             String header = headers.nextElement();
271             String[] tokens = header.split(",");
272             for (String token : tokens) {
273                 result.add(token.trim());
274             }
275         }
276         return result;
277     }
278 
279 
getWebSocketAccept(String key)280     private static String getWebSocketAccept(String key) {
281         byte[] digest = ConcurrentMessageDigest.digestSHA1(
282                 key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
283         return Base64.encodeBase64String(digest);
284     }
285 }
286