1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "remoting/host/security_key/security_key_socket.h"
6 
#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <utility>
8 
9 #include "base/bind.h"
10 #include "base/stl_util.h"
11 #include "base/timer/timer.h"
12 #include "net/base/io_buffer.h"
13 #include "net/base/net_errors.h"
14 #include "net/socket/stream_socket.h"
15 #include "net/traffic_annotation/network_traffic_annotation.h"
16 
17 namespace remoting {
18 
namespace {

// Every request and response is prefixed by a length header of this size.
constexpr size_t kRequestSizeBytes = 4;

// Largest accepted request, measured including the length header (the value
// computed by GetRequestLength() is compared against this).
constexpr size_t kMaxRequestLength = 16384;

// Size of the socket read buffer: room for a header plus a maximal request.
constexpr size_t kRequestReadBufferLength =
    kRequestSizeBytes + kMaxRequestLength;

// Single-byte SSH failure code sent back by SendSshError().
constexpr char kSshError[] = {0x05};

}  // namespace
29 
SecurityKeySocket(std::unique_ptr<net::StreamSocket> socket,base::TimeDelta timeout,base::OnceClosure timeout_callback)30 SecurityKeySocket::SecurityKeySocket(std::unique_ptr<net::StreamSocket> socket,
31                                      base::TimeDelta timeout,
32                                      base::OnceClosure timeout_callback)
33     : socket_(std::move(socket)),
34       read_buffer_(base::MakeRefCounted<net::IOBufferWithSize>(
35           kRequestReadBufferLength)) {
36   timer_.reset(new base::OneShotTimer());
37   timer_->Start(FROM_HERE, timeout, std::move(timeout_callback));
38 }
39 
~SecurityKeySocket()40 SecurityKeySocket::~SecurityKeySocket() {
41   DCHECK(thread_checker_.CalledOnValidThread());
42 }
43 
GetAndClearRequestData(std::string * data_out)44 bool SecurityKeySocket::GetAndClearRequestData(std::string* data_out) {
45   DCHECK(thread_checker_.CalledOnValidThread());
46   DCHECK(!waiting_for_request_);
47 
48   if (!IsRequestComplete() || IsRequestTooLarge()) {
49     return false;
50   }
51   // The request size is not part of the data; don't send it.
52   data_out->assign(request_data_.begin() + kRequestSizeBytes,
53                    request_data_.end());
54   request_data_.clear();
55   return true;
56 }
57 
SendResponse(const std::string & response_data)58 void SecurityKeySocket::SendResponse(const std::string& response_data) {
59   DCHECK(thread_checker_.CalledOnValidThread());
60   DCHECK(!write_buffer_);
61 
62   std::string response_length_string = GetResponseLengthAsBytes(response_data);
63   int response_len = response_length_string.size() + response_data.size();
64   std::unique_ptr<std::string> response(
65       new std::string(response_length_string + response_data));
66   write_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(
67       base::MakeRefCounted<net::StringIOBuffer>(std::move(response)),
68       response_len);
69 
70   DCHECK(write_buffer_->BytesRemaining());
71   DoWrite();
72 }
73 
SendSshError()74 void SecurityKeySocket::SendSshError() {
75   DCHECK(thread_checker_.CalledOnValidThread());
76 
77   SendResponse(std::string(kSshError, base::size(kSshError)));
78 }
79 
StartReadingRequest(base::OnceClosure request_received_callback)80 void SecurityKeySocket::StartReadingRequest(
81     base::OnceClosure request_received_callback) {
82   DCHECK(thread_checker_.CalledOnValidThread());
83   DCHECK(!request_received_callback_);
84 
85   waiting_for_request_ = true;
86   request_received_callback_ = std::move(request_received_callback);
87 
88   DoRead();
89 }
90 
OnDataWritten(int result)91 void SecurityKeySocket::OnDataWritten(int result) {
92   DCHECK(thread_checker_.CalledOnValidThread());
93   DCHECK(write_buffer_);
94 
95   if (result < 0) {
96     LOG(ERROR) << "Error sending response: " << result;
97     return;
98   }
99   ResetTimer();
100   write_buffer_->DidConsume(result);
101 
102   if (!write_buffer_->BytesRemaining()) {
103     write_buffer_ = nullptr;
104     return;
105   }
106 
107   DoWrite();
108 }
109 
DoWrite()110 void SecurityKeySocket::DoWrite() {
111   DCHECK(thread_checker_.CalledOnValidThread());
112   DCHECK(write_buffer_);
113   net::NetworkTrafficAnnotationTag traffic_annotation =
114       net::DefineNetworkTrafficAnnotation("security_key_socket", R"(
115         semantics {
116           sender: "Chrome Remote Desktop"
117           description:
118             "This request performs the communication between processes when "
119             "handling security key (gnubby) authentication."
120           trigger:
121             "Performing an action (such as signing into a website with "
122             "two-factor authentication enabled) that requires a security key "
123             "touch."
124           data: "Security key protocol data."
125           destination: LOCAL
126         }
127         policy {
128           cookies_allowed: NO
129           setting: "This feature cannot be disabled in Settings."
130           chrome_policy {
131             RemoteAccessHostAllowGnubbyAuth {
132               RemoteAccessHostAllowGnubbyAuth: false
133             }
134           }
135         })");
136   int result = socket_->Write(
137       write_buffer_.get(), write_buffer_->BytesRemaining(),
138       base::BindOnce(&SecurityKeySocket::OnDataWritten, base::Unretained(this)),
139       traffic_annotation);
140   if (result != net::ERR_IO_PENDING) {
141     OnDataWritten(result);
142   }
143 }
144 
OnDataRead(int result)145 void SecurityKeySocket::OnDataRead(int result) {
146   DCHECK(thread_checker_.CalledOnValidThread());
147 
148   if (result <= 0) {
149     if (result < 0) {
150       LOG(ERROR) << "Error reading request: " << result;
151       socket_read_error_ = true;
152     }
153     waiting_for_request_ = false;
154     std::move(request_received_callback_).Run();
155     return;
156   }
157 
158   ResetTimer();
159   // TODO(joedow): If there are multiple requests in a burst, it is possible
160   // that we could read too many bytes from the buffer (e.g. all of request #1
161   // and some of request #2).  We should consider using the request header to
162   // determine the request length and only read that amount from buffer.
163   request_data_.insert(request_data_.end(), read_buffer_->data(),
164                        read_buffer_->data() + result);
165   if (IsRequestComplete()) {
166     waiting_for_request_ = false;
167     std::move(request_received_callback_).Run();
168     return;
169   }
170 
171   DoRead();
172 }
173 
DoRead()174 void SecurityKeySocket::DoRead() {
175   DCHECK(thread_checker_.CalledOnValidThread());
176 
177   int result = socket_->Read(
178       read_buffer_.get(), kRequestReadBufferLength,
179       base::BindOnce(&SecurityKeySocket::OnDataRead, base::Unretained(this)));
180   if (result != net::ERR_IO_PENDING) {
181     OnDataRead(result);
182   }
183 }
184 
IsRequestComplete() const185 bool SecurityKeySocket::IsRequestComplete() const {
186   DCHECK(thread_checker_.CalledOnValidThread());
187 
188   if (request_data_.size() < kRequestSizeBytes) {
189     return false;
190   }
191   return GetRequestLength() <= request_data_.size();
192 }
193 
IsRequestTooLarge() const194 bool SecurityKeySocket::IsRequestTooLarge() const {
195   DCHECK(thread_checker_.CalledOnValidThread());
196 
197   if (request_data_.size() < kRequestSizeBytes) {
198     return false;
199   }
200   return GetRequestLength() > kMaxRequestLength;
201 }
202 
GetRequestLength() const203 size_t SecurityKeySocket::GetRequestLength() const {
204   DCHECK(request_data_.size() >= kRequestSizeBytes);
205 
206   return ((request_data_[0] & 255) << 24) + ((request_data_[1] & 255) << 16) +
207          ((request_data_[2] & 255) << 8) + (request_data_[3] & 255) +
208          kRequestSizeBytes;
209 }
210 
GetResponseLengthAsBytes(const std::string & response) const211 std::string SecurityKeySocket::GetResponseLengthAsBytes(
212     const std::string& response) const {
213   std::string response_len;
214   response_len.reserve(kRequestSizeBytes);
215   int len = response.size();
216 
217   response_len.push_back((len >> 24) & 255);
218   response_len.push_back((len >> 16) & 255);
219   response_len.push_back((len >> 8) & 255);
220   response_len.push_back(len & 255);
221 
222   return response_len;
223 }
224 
ResetTimer()225 void SecurityKeySocket::ResetTimer() {
226   if (timer_->IsRunning()) {
227     timer_->Reset();
228   }
229 }
230 
231 }  // namespace remoting
232