1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_base.h"
6 
7 #include <unordered_set>
8 
9 #include "base/metrics/histogram_macros.h"
10 #include "base/strings/string_util.h"
11 #include "net/http/http_request_headers.h"
12 #include "net/http/http_response_headers.h"
13 #include "net/websockets/websocket_extension.h"
14 #include "net/websockets/websocket_extension_parser.h"
15 #include "net/websockets/websocket_handshake_constants.h"
16 
17 namespace net {
18 
19 // static
MultipleHeaderValuesMessage(const std::string & header_name)20 std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
21     const std::string& header_name) {
22   return std::string("'") + header_name +
23          "' header must not appear more than once in a response";
24 }
25 
26 // static
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)27 void WebSocketHandshakeStreamBase::AddVectorHeaderIfNonEmpty(
28     const char* name,
29     const std::vector<std::string>& value,
30     HttpRequestHeaders* headers) {
31   if (value.empty())
32     return;
33   headers->SetHeader(name, base::JoinString(value, ", "));
34 }
35 
36 // static
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)37 bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
38     const HttpResponseHeaders* headers,
39     const std::vector<std::string>& requested_sub_protocols,
40     std::string* sub_protocol,
41     std::string* failure_message) {
42   size_t iter = 0;
43   std::string value;
44   std::unordered_set<std::string> requested_set(requested_sub_protocols.begin(),
45                                                 requested_sub_protocols.end());
46   int count = 0;
47   bool has_multiple_protocols = false;
48   bool has_invalid_protocol = false;
49 
50   while (!has_invalid_protocol || !has_multiple_protocols) {
51     std::string temp_value;
52     if (!headers->EnumerateHeader(&iter, websockets::kSecWebSocketProtocol,
53                                   &temp_value))
54       break;
55     value = temp_value;
56     if (requested_set.count(value) == 0)
57       has_invalid_protocol = true;
58     if (++count > 1)
59       has_multiple_protocols = true;
60   }
61 
62   if (has_multiple_protocols) {
63     *failure_message =
64         MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
65     return false;
66   } else if (count > 0 && requested_sub_protocols.size() == 0) {
67     *failure_message = std::string(
68                            "Response must not include 'Sec-WebSocket-Protocol' "
69                            "header if not present in request: ") +
70                        value;
71     return false;
72   } else if (has_invalid_protocol) {
73     *failure_message = "'Sec-WebSocket-Protocol' header value '" + value +
74                        "' in response does not match any of sent values";
75     return false;
76   } else if (requested_sub_protocols.size() > 0 && count == 0) {
77     *failure_message =
78         "Sent non-empty 'Sec-WebSocket-Protocol' header "
79         "but no response was received";
80     return false;
81   }
82   *sub_protocol = value;
83   return true;
84 }
85 
86 // static
ValidateExtensions(const HttpResponseHeaders * headers,std::string * accepted_extensions_descriptor,std::string * failure_message,WebSocketExtensionParams * params)87 bool WebSocketHandshakeStreamBase::ValidateExtensions(
88     const HttpResponseHeaders* headers,
89     std::string* accepted_extensions_descriptor,
90     std::string* failure_message,
91     WebSocketExtensionParams* params) {
92   size_t iter = 0;
93   std::string header_value;
94   std::vector<std::string> header_values;
95   // TODO(ricea): If adding support for additional extensions, generalise this
96   // code.
97   bool seen_permessage_deflate = false;
98   while (headers->EnumerateHeader(&iter, websockets::kSecWebSocketExtensions,
99                                   &header_value)) {
100     WebSocketExtensionParser parser;
101     if (!parser.Parse(header_value)) {
102       // TODO(yhirano) Set appropriate failure message.
103       *failure_message =
104           "'Sec-WebSocket-Extensions' header value is "
105           "rejected by the parser: " +
106           header_value;
107       return false;
108     }
109 
110     const std::vector<WebSocketExtension>& extensions = parser.extensions();
111     for (const auto& extension : extensions) {
112       if (extension.name() == "permessage-deflate") {
113         if (seen_permessage_deflate) {
114           *failure_message = "Received duplicate permessage-deflate response";
115           return false;
116         }
117         seen_permessage_deflate = true;
118         auto& deflate_parameters = params->deflate_parameters;
119         if (!deflate_parameters.Initialize(extension, failure_message) ||
120             !deflate_parameters.IsValidAsResponse(failure_message)) {
121           *failure_message = "Error in permessage-deflate: " + *failure_message;
122           return false;
123         }
124         // Note that we don't have to check the request-response compatibility
125         // here because we send a request compatible with any valid responses.
126         // TODO(yhirano): Place a DCHECK here.
127 
128         header_values.push_back(header_value);
129       } else {
130         *failure_message = "Found an unsupported extension '" +
131                            extension.name() +
132                            "' in 'Sec-WebSocket-Extensions' header";
133         return false;
134       }
135     }
136   }
137   *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
138   params->deflate_enabled = seen_permessage_deflate;
139   return true;
140 }
141 
// Records |result| in the "Net.WebSocket.HandshakeResult2" enumeration
// histogram. NUM_HANDSHAKE_RESULT_TYPES is passed as the exclusive upper
// bound of the enum. The histogram name is a runtime identifier and must not
// be changed without also updating the histogram's definition (presumably in
// histograms.xml; verify).
void WebSocketHandshakeStreamBase::RecordHandshakeResult(
    HandshakeResult result) {
  UMA_HISTOGRAM_ENUMERATION("Net.WebSocket.HandshakeResult2", result,
                            HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES);
}
147 
148 }  // namespace net
149