1 // Copyright (c) 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/test/chromedriver/server/http_server.h"
6 
7 namespace {
8 
9 // Maximum message size between app and ChromeDriver. Data larger than 150 MB
10 // or so can cause crashes in Chrome (https://crbug.com/890854), so there is no
11 // need to support messages that are too large.
12 const int kBufferSize = 256 * 1024 * 1024;  // 256 MB
13 
ListenOnIPv4(net::ServerSocket * socket,uint16_t port,bool allow_remote)14 int ListenOnIPv4(net::ServerSocket* socket, uint16_t port, bool allow_remote) {
15   std::string binding_ip = net::IPAddress::IPv4Localhost().ToString();
16   if (allow_remote)
17     binding_ip = net::IPAddress::IPv4AllZeros().ToString();
18   return socket->ListenWithAddressAndPort(binding_ip, port, 5);
19 }
20 
ListenOnIPv6(net::ServerSocket * socket,uint16_t port,bool allow_remote)21 int ListenOnIPv6(net::ServerSocket* socket, uint16_t port, bool allow_remote) {
22   std::string binding_ip = net::IPAddress::IPv6Localhost().ToString();
23   if (allow_remote)
24     binding_ip = net::IPAddress::IPv6AllZeros().ToString();
25   return socket->ListenWithAddressAndPort(binding_ip, port, 5);
26 }
27 
RequestIsSafeToServe(const net::HttpServerRequestInfo & info,bool allow_remote,const std::vector<net::IPAddress> & whitelisted_ips)28 bool RequestIsSafeToServe(const net::HttpServerRequestInfo& info,
29                           bool allow_remote,
30                           const std::vector<net::IPAddress>& whitelisted_ips) {
31   // To guard against browser-originating cross-site requests, when host header
32   // and/or origin header are present, serve only those coming from localhost
33   // or from an explicitly whitelisted ip.
34   std::string origin_header = info.GetHeaderValue("origin");
35   bool local_origin = false;
36   if (!origin_header.empty()) {
37     GURL url = GURL(origin_header);
38     local_origin = net::IsLocalhost(url);
39     if (!local_origin) {
40       if (!allow_remote) {
41         LOG(ERROR)
42             << "Remote connections not allowed; rejecting request with origin: "
43             << origin_header;
44         return false;
45       }
46       if (!whitelisted_ips.empty()) {
47         net::IPAddress address = net::IPAddress();
48         if (!ParseURLHostnameToAddress(origin_header, &address)) {
49           LOG(ERROR) << "Unable to parse origin to IPAddress: "
50                      << origin_header;
51           return false;
52         }
53         if (!base::Contains(whitelisted_ips, address)) {
54           LOG(ERROR) << "Rejecting request with origin: " << origin_header;
55           return false;
56         }
57       }
58     }
59   }
60   // TODO https://crbug.com/chromedriver/3389
61   //  When remote access is allowed and origin is not specified,
62   // we should confirm that host is current machines ip or hostname
63 
64   if (local_origin || !allow_remote) {
65     // when origin is localhost host must be localhost
66     // when origin is not set, and no remote access, host must be localhost
67     std::string host_header = info.GetHeaderValue("host");
68     if (!host_header.empty()) {
69       GURL url = GURL("http://" + host_header);
70       if (!net::IsLocalhost(url)) {
71         LOG(ERROR) << "Rejecting request with host: " << host_header
72                    << ". origin is " << origin_header;
73         return false;
74       }
75     }
76   }
77   return true;
78 }
79 
80 }  // namespace
81 
HttpServer(const std::string & url_base,const std::vector<net::IPAddress> & whitelisted_ips,const HttpRequestHandlerFunc & handle_request_func,base::WeakPtr<HttpHandler> handler,scoped_refptr<base::SingleThreadTaskRunner> cmd_runner)82 HttpServer::HttpServer(const std::string& url_base,
83                        const std::vector<net::IPAddress>& whitelisted_ips,
84                        const HttpRequestHandlerFunc& handle_request_func,
85                        base::WeakPtr<HttpHandler> handler,
86                        scoped_refptr<base::SingleThreadTaskRunner> cmd_runner)
87     : url_base_(url_base),
88       handle_request_func_(handle_request_func),
89       allow_remote_(false),
90       whitelisted_ips_(whitelisted_ips),
91       handler_(handler),
92       cmd_runner_(cmd_runner) {}
93 
Start(uint16_t port,bool allow_remote,bool use_ipv4)94 int HttpServer::Start(uint16_t port, bool allow_remote, bool use_ipv4) {
95   allow_remote_ = allow_remote;
96   std::unique_ptr<net::ServerSocket> server_socket(
97       new net::TCPServerSocket(nullptr, net::NetLogSource()));
98   int status = use_ipv4 ? ListenOnIPv4(server_socket.get(), port, allow_remote)
99                         : ListenOnIPv6(server_socket.get(), port, allow_remote);
100   if (status != net::OK) {
101     VLOG(0) << "listen on " << (use_ipv4 ? "IPv4" : "IPv6")
102             << " failed with error " << net::ErrorToShortString(status);
103     return status;
104   }
105   server_ = std::make_unique<net::HttpServer>(std::move(server_socket), this);
106   net::IPEndPoint address;
107   return server_->GetLocalAddress(&address);
108 }
109 
OnConnect(int connection_id)110 void HttpServer::OnConnect(int connection_id) {
111   server_->SetSendBufferSize(connection_id, kBufferSize);
112   server_->SetReceiveBufferSize(connection_id, kBufferSize);
113 }
114 
OnHttpRequest(int connection_id,const net::HttpServerRequestInfo & info)115 void HttpServer::OnHttpRequest(int connection_id,
116                                const net::HttpServerRequestInfo& info) {
117   if (!RequestIsSafeToServe(info, allow_remote_, whitelisted_ips_)) {
118     server_->Send500(connection_id,
119                      "Host header or origin header is specified and is not "
120                      "whitelisted or localhost.",
121                      TRAFFIC_ANNOTATION_FOR_TESTS);
122     return;
123   }
124   handle_request_func_.Run(
125       info, base::BindRepeating(&HttpServer::OnResponse,
126                                 weak_factory_.GetWeakPtr(), connection_id,
127                                 !info.HasHeaderValue("connection", "close")));
128 }
129 
130 HttpServer::~HttpServer() = default;
131 
OnWebSocketRequest(int connection_id,const net::HttpServerRequestInfo & info)132 void HttpServer::OnWebSocketRequest(int connection_id,
133                                     const net::HttpServerRequestInfo& info) {
134   cmd_runner_->PostTask(
135       FROM_HERE, base::BindOnce(&HttpHandler::OnWebSocketRequest, handler_,
136                                 this, connection_id, info));
137 }
138 
OnWebSocketMessage(int connection_id,std::string data)139 void HttpServer::OnWebSocketMessage(int connection_id, std::string data) {
140   // TODO: Make use of WebSocket data
141   VLOG(0) << "HttpServer::OnWebSocketMessage received: " << data;
142 }
143 
OnClose(int connection_id)144 void HttpServer::OnClose(int connection_id) {
145   cmd_runner_->PostTask(
146       FROM_HERE,
147       base::BindOnce(&HttpHandler::OnClose, handler_, this, connection_id));
148 }
149 
AcceptWebSocket(int connection_id,const net::HttpServerRequestInfo & request)150 void HttpServer::AcceptWebSocket(int connection_id,
151                                  const net::HttpServerRequestInfo& request) {
152   server_->AcceptWebSocket(connection_id, request,
153                            TRAFFIC_ANNOTATION_FOR_TESTS);
154 }
155 
SendResponse(int connection_id,const net::HttpServerResponseInfo & response,const net::NetworkTrafficAnnotationTag & traffic_annotation)156 void HttpServer::SendResponse(
157     int connection_id,
158     const net::HttpServerResponseInfo& response,
159     const net::NetworkTrafficAnnotationTag& traffic_annotation) {
160   server_->SendResponse(connection_id, response, traffic_annotation);
161 }
162 
OnResponse(int connection_id,bool keep_alive,std::unique_ptr<net::HttpServerResponseInfo> response)163 void HttpServer::OnResponse(
164     int connection_id,
165     bool keep_alive,
166     std::unique_ptr<net::HttpServerResponseInfo> response) {
167   if (!keep_alive)
168     response->AddHeader("Connection", "close");
169   server_->SendResponse(connection_id, *response, TRAFFIC_ANNOTATION_FOR_TESTS);
170   // Don't need to call server_->Close(), since SendResponse() will handle
171   // this for us.
172 }
173