1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package nginx.unit.websocket.server; 18 19 import java.io.IOException; 20 import java.nio.charset.StandardCharsets; 21 import java.util.ArrayList; 22 import java.util.Collections; 23 import java.util.Enumeration; 24 import java.util.LinkedHashMap; 25 import java.util.List; 26 import java.util.Map; 27 import java.util.Map.Entry; 28 29 import javax.servlet.ServletException; 30 import javax.servlet.ServletRequest; 31 import javax.servlet.ServletResponse; 32 import javax.servlet.http.HttpServletRequest; 33 import javax.servlet.http.HttpServletResponse; 34 import javax.websocket.Endpoint; 35 import javax.websocket.Extension; 36 import javax.websocket.HandshakeResponse; 37 import javax.websocket.server.ServerEndpointConfig; 38 39 import nginx.unit.Request; 40 41 import org.apache.tomcat.util.codec.binary.Base64; 42 import org.apache.tomcat.util.res.StringManager; 43 import org.apache.tomcat.util.security.ConcurrentMessageDigest; 44 import nginx.unit.websocket.Constants; 45 import nginx.unit.websocket.Transformation; 46 import nginx.unit.websocket.TransformationFactory; 47 import nginx.unit.websocket.Util; 48 import nginx.unit.websocket.WsHandshakeResponse; 49 import nginx.unit.websocket.pojo.PojoEndpointServer; 50 51 public class UpgradeUtil { 52 53 private static final StringManager sm = 54 StringManager.getManager(UpgradeUtil.class.getPackage().getName()); 55 private static final byte[] WS_ACCEPT = 56 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( 57 StandardCharsets.ISO_8859_1); 58 UpgradeUtil()59 private UpgradeUtil() { 60 // Utility class. Hide default constructor. 61 } 62 63 /** 64 * Checks to see if this is an HTTP request that includes a valid upgrade 65 * request to web socket. 66 * <p> 67 * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java 68 * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC 69 * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. 70 * @param request The request to check if it is an HTTP upgrade request for 71 * a WebSocket connection 72 * @param response The response associated with the request 73 * @return <code>true</code> if the request includes a HTTP Upgrade request 74 * for the WebSocket protocol, otherwise <code>false</code> 75 */ isWebSocketUpgradeRequest(ServletRequest request, ServletResponse response)76 public static boolean isWebSocketUpgradeRequest(ServletRequest request, 77 ServletResponse response) { 78 79 Request r = (Request) request.getAttribute(Request.BARE); 80 81 return ((request instanceof HttpServletRequest) && 82 (response instanceof HttpServletResponse) && 83 (r != null) && 84 (r.isUpgrade())); 85 } 86 87 doUpgrade(WsServerContainer sc, HttpServletRequest req, HttpServletResponse resp, ServerEndpointConfig sec, Map<String,String> pathParams)88 public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, 89 HttpServletResponse resp, ServerEndpointConfig sec, 90 Map<String,String> pathParams) 91 throws ServletException, IOException { 92 93 94 // Origin check 95 String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); 96 97 if (!sec.getConfigurator().checkOrigin(origin)) { 98 resp.sendError(HttpServletResponse.SC_FORBIDDEN); 99 return; 100 } 101 // Sub-protocols 102 List<String> subProtocols = getTokensFromHeader(req, 103 Constants.WS_PROTOCOL_HEADER_NAME); 104 String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( 105 sec.getSubprotocols(), subProtocols); 106 107 // Extensions 108 // Should normally only be one header but handle the case of multiple 109 // headers 110 List<Extension> extensionsRequested = new ArrayList<>(); 111 Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); 112 while (extHeaders.hasMoreElements()) { 113 Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); 114 } 115 116 // Negotiation phase 1. By default this simply filters out the 117 // extensions that the server does not support but applications could 118 // use a custom configurator to do more than this. 119 List<Extension> installedExtensions = null; 120 if (sec.getExtensions().size() == 0) { 121 installedExtensions = Constants.INSTALLED_EXTENSIONS; 122 } else { 123 installedExtensions = new ArrayList<>(); 124 installedExtensions.addAll(sec.getExtensions()); 125 installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); 126 } 127 List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( 128 installedExtensions, extensionsRequested); 129 130 // Negotiation phase 2. Create the Transformations that will be applied 131 // to this connection. Note than an extension may be dropped at this 132 // point if the client has requested a configuration that the server is 133 // unable to support. 134 List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1); 135 136 List<Extension> negotiatedExtensionsPhase2; 137 if (transformations.isEmpty()) { 138 negotiatedExtensionsPhase2 = Collections.emptyList(); 139 } else { 140 negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); 141 for (Transformation t : transformations) { 142 negotiatedExtensionsPhase2.add(t.getExtensionResponse()); 143 } 144 } 145 146 WsHttpUpgradeHandler wsHandler = 147 req.upgrade(WsHttpUpgradeHandler.class); 148 149 WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); 150 WsHandshakeResponse wsResponse = new WsHandshakeResponse(); 151 WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = 152 new WsPerSessionServerEndpointConfig(sec); 153 sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, 154 wsRequest, wsResponse); 155 //wsRequest.finished(); 156 157 // Add any additional headers 158 for (Entry<String,List<String>> entry : 159 wsResponse.getHeaders().entrySet()) { 160 for (String headerValue: entry.getValue()) { 161 resp.addHeader(entry.getKey(), headerValue); 162 } 163 } 164 165 Endpoint ep; 166 try { 167 Class<?> clazz = sec.getEndpointClass(); 168 if (Endpoint.class.isAssignableFrom(clazz)) { 169 ep = (Endpoint) sec.getConfigurator().getEndpointInstance( 170 clazz); 171 } else { 172 ep = new PojoEndpointServer(); 173 // Need to make path params available to POJO 174 perSessionServerEndpointConfig.getUserProperties().put( 175 nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); 176 } 177 } catch (InstantiationException e) { 178 throw new ServletException(e); 179 } 180 181 wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, 182 negotiatedExtensionsPhase2, subProtocol, null, pathParams, 183 req.isSecure()); 184 185 wsHandler.init(null); 186 } 187 188 createTransformations( List<Extension> negotiatedExtensions)189 private static List<Transformation> createTransformations( 190 List<Extension> negotiatedExtensions) { 191 192 TransformationFactory factory = TransformationFactory.getInstance(); 193 194 LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences = 195 new LinkedHashMap<>(); 196 197 // Result will likely be smaller than this 198 List<Transformation> result = new ArrayList<>(negotiatedExtensions.size()); 199 200 for (Extension extension : negotiatedExtensions) { 201 List<List<Extension.Parameter>> preferences = 202 extensionPreferences.get(extension.getName()); 203 204 if (preferences == null) { 205 preferences = new ArrayList<>(); 206 extensionPreferences.put(extension.getName(), preferences); 207 } 208 209 preferences.add(extension.getParameters()); 210 } 211 212 for (Map.Entry<String,List<List<Extension.Parameter>>> entry : 213 extensionPreferences.entrySet()) { 214 Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); 215 if (transformation != null) { 216 result.add(transformation); 217 } 218 } 219 return result; 220 } 221 222 append(StringBuilder sb, Extension extension)223 private static void append(StringBuilder sb, Extension extension) { 224 if (extension == null || extension.getName() == null || extension.getName().length() == 0) { 225 return; 226 } 227 228 sb.append(extension.getName()); 229 230 for (Extension.Parameter p : extension.getParameters()) { 231 sb.append(';'); 232 sb.append(p.getName()); 233 if (p.getValue() != null) { 234 sb.append('='); 235 sb.append(p.getValue()); 236 } 237 } 238 } 239 240 241 /* 242 * This only works for tokens. Quoted strings need more sophisticated 243 * parsing. 244 */ headerContainsToken(HttpServletRequest req, String headerName, String target)245 private static boolean headerContainsToken(HttpServletRequest req, 246 String headerName, String target) { 247 Enumeration<String> headers = req.getHeaders(headerName); 248 while (headers.hasMoreElements()) { 249 String header = headers.nextElement(); 250 String[] tokens = header.split(","); 251 for (String token : tokens) { 252 if (target.equalsIgnoreCase(token.trim())) { 253 return true; 254 } 255 } 256 } 257 return false; 258 } 259 260 261 /* 262 * This only works for tokens. Quoted strings need more sophisticated 263 * parsing. 264 */ getTokensFromHeader(HttpServletRequest req, String headerName)265 private static List<String> getTokensFromHeader(HttpServletRequest req, 266 String headerName) { 267 List<String> result = new ArrayList<>(); 268 Enumeration<String> headers = req.getHeaders(headerName); 269 while (headers.hasMoreElements()) { 270 String header = headers.nextElement(); 271 String[] tokens = header.split(","); 272 for (String token : tokens) { 273 result.add(token.trim()); 274 } 275 } 276 return result; 277 } 278 279 getWebSocketAccept(String key)280 private static String getWebSocketAccept(String key) { 281 byte[] digest = ConcurrentMessageDigest.digestSHA1( 282 key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); 283 return Base64.encodeBase64String(digest); 284 } 285 } 286