1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks5_client_socket.h"
6 
7 #if defined(OS_BSD)
8 #include <netinet/in.h>
9 #endif
10 
11 #include <utility>
12 
13 #include "base/bind.h"
14 #include "base/callback_helpers.h"
15 #include "base/compiler_specific.h"
16 #include "base/format_macros.h"
17 #include "base/stl_util.h"
18 #include "base/strings/string_util.h"
19 #include "base/sys_byteorder.h"
20 #include "base/trace_event/trace_event.h"
21 #include "net/base/io_buffer.h"
22 #include "net/log/net_log.h"
23 #include "net/log/net_log_event_type.h"
24 #include "net/traffic_annotation/network_traffic_annotation.h"
25 
26 namespace net {
27 
28 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
29 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
30 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
31 const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
32 const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
33 const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;
34 
35 static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
36 static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
37 
// Wraps |transport_socket|. Connect() later negotiates a SOCKS5 CONNECT to
// |destination| over it. |traffic_annotation| is passed to the handshake
// writes made on the transport.
//
// NOTE: net_log_ is initialized from transport_socket_, so transport_socket_
// must be declared before net_log_ in the header (members initialize in
// declaration order, not in the order listed here).
SOCKS5ClientSocket::SOCKS5ClientSocket(
    std::unique_ptr<StreamSocket> transport_socket,
    const HostPortPair& destination,
    const NetworkTrafficAnnotationTag& traffic_annotation)
    // Single repeating callback reused for every handshake read/write.
    // Unretained is presumably safe because |this| owns the transport
    // socket that holds the callback -- TODO confirm.
    : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
                                       base::Unretained(this))),
      transport_socket_(std::move(transport_socket)),
      next_state_(STATE_NONE),
      completed_handshake_(false),
      bytes_sent_(0),
      bytes_received_(0),
      read_header_size(kReadHeaderSize),
      was_ever_used_(false),
      destination_(destination),
      // Log under the transport socket's source.
      net_log_(transport_socket_->NetLog()),
      traffic_annotation_(traffic_annotation) {}
54 
~SOCKS5ClientSocket()55 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
56   Disconnect();
57 }
58 
Connect(CompletionOnceCallback callback)59 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
60   DCHECK(transport_socket_);
61   DCHECK_EQ(STATE_NONE, next_state_);
62   DCHECK(user_callback_.is_null());
63 
64   // If already connected, then just return OK.
65   if (completed_handshake_)
66     return OK;
67 
68   net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
69 
70   next_state_ = STATE_GREET_WRITE;
71   buffer_.clear();
72 
73   int rv = DoLoop(OK);
74   if (rv == ERR_IO_PENDING) {
75     user_callback_ = std::move(callback);
76   } else {
77     net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
78   }
79   return rv;
80 }
81 
Disconnect()82 void SOCKS5ClientSocket::Disconnect() {
83   completed_handshake_ = false;
84   transport_socket_->Disconnect();
85 
86   // Reset other states to make sure they aren't mistakenly used later.
87   // These are the states initialized by Connect().
88   next_state_ = STATE_NONE;
89   user_callback_.Reset();
90 }
91 
IsConnected() const92 bool SOCKS5ClientSocket::IsConnected() const {
93   return completed_handshake_ && transport_socket_->IsConnected();
94 }
95 
IsConnectedAndIdle() const96 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
97   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
98 }
99 
NetLog() const100 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
101   return net_log_;
102 }
103 
WasEverUsed() const104 bool SOCKS5ClientSocket::WasEverUsed() const {
105   return was_ever_used_;
106 }
107 
WasAlpnNegotiated() const108 bool SOCKS5ClientSocket::WasAlpnNegotiated() const {
109   if (transport_socket_)
110     return transport_socket_->WasAlpnNegotiated();
111   NOTREACHED();
112   return false;
113 }
114 
GetNegotiatedProtocol() const115 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
116   if (transport_socket_)
117     return transport_socket_->GetNegotiatedProtocol();
118   NOTREACHED();
119   return kProtoUnknown;
120 }
121 
GetSSLInfo(SSLInfo * ssl_info)122 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
123   if (transport_socket_)
124     return transport_socket_->GetSSLInfo(ssl_info);
125   NOTREACHED();
126   return false;
127 }
128 
GetConnectionAttempts(ConnectionAttempts * out) const129 void SOCKS5ClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
130   out->clear();
131 }
132 
GetTotalReceivedBytes() const133 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
134   return transport_socket_->GetTotalReceivedBytes();
135 }
136 
ApplySocketTag(const SocketTag & tag)137 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
138   return transport_socket_->ApplySocketTag(tag);
139 }
140 
141 // Read is called by the transport layer above to read. This can only be done
142 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)143 int SOCKS5ClientSocket::Read(IOBuffer* buf,
144                              int buf_len,
145                              CompletionOnceCallback callback) {
146   DCHECK(completed_handshake_);
147   DCHECK_EQ(STATE_NONE, next_state_);
148   DCHECK(user_callback_.is_null());
149   DCHECK(!callback.is_null());
150 
151   int rv = transport_socket_->Read(
152       buf, buf_len,
153       base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
154                      base::Unretained(this), std::move(callback)));
155   if (rv > 0)
156     was_ever_used_ = true;
157   return rv;
158 }
159 
160 // Write is called by the transport layer. This can only be done if the
161 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)162 int SOCKS5ClientSocket::Write(
163     IOBuffer* buf,
164     int buf_len,
165     CompletionOnceCallback callback,
166     const NetworkTrafficAnnotationTag& traffic_annotation) {
167   DCHECK(completed_handshake_);
168   DCHECK_EQ(STATE_NONE, next_state_);
169   DCHECK(user_callback_.is_null());
170   DCHECK(!callback.is_null());
171 
172   int rv = transport_socket_->Write(
173       buf, buf_len,
174       base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
175                      base::Unretained(this), std::move(callback)),
176       traffic_annotation);
177   if (rv > 0)
178     was_ever_used_ = true;
179   return rv;
180 }
181 
SetReceiveBufferSize(int32_t size)182 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
183   return transport_socket_->SetReceiveBufferSize(size);
184 }
185 
SetSendBufferSize(int32_t size)186 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
187   return transport_socket_->SetSendBufferSize(size);
188 }
189 
DoCallback(int result)190 void SOCKS5ClientSocket::DoCallback(int result) {
191   DCHECK_NE(ERR_IO_PENDING, result);
192   DCHECK(!user_callback_.is_null());
193 
194   // Since Run() may result in Read being called,
195   // clear user_callback_ up front.
196   std::move(user_callback_).Run(result);
197 }
198 
OnIOComplete(int result)199 void SOCKS5ClientSocket::OnIOComplete(int result) {
200   DCHECK_NE(STATE_NONE, next_state_);
201   int rv = DoLoop(result);
202   if (rv != ERR_IO_PENDING) {
203     net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
204     DoCallback(rv);
205   }
206 }
207 
OnReadWriteComplete(CompletionOnceCallback callback,int result)208 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
209                                              int result) {
210   DCHECK_NE(ERR_IO_PENDING, result);
211   DCHECK(!callback.is_null());
212 
213   if (result > 0)
214     was_ever_used_ = true;
215   std::move(callback).Run(result);
216 }
217 
// Drives the handshake state machine, starting from next_state_.
// |last_io_result| is the result of the I/O that just completed (Connect()
// passes OK). Each step's return value is fed into the next step. The loop
// stops when a step returns ERR_IO_PENDING or when no further state has been
// scheduled (next_state_ == STATE_NONE).
// Returns ERR_IO_PENDING, or otherwise the result of the last step run
// (OK on success, a net error on failure).
int SOCKS5ClientSocket::DoLoop(int last_io_result) {
  DCHECK_NE(next_state_, STATE_NONE);
  int rv = last_io_result;
  do {
    State state = next_state_;
    // Cleared before each step; a step that wants to continue sets it again,
    // so a step that returns without scheduling one ends the loop.
    next_state_ = STATE_NONE;
    switch (state) {
      // Each I/O step opens a NetLog event that its *_COMPLETE step closes
      // with the step's result.
      case STATE_GREET_WRITE:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
        rv = DoGreetWrite();
        break;
      case STATE_GREET_WRITE_COMPLETE:
        rv = DoGreetWriteComplete(rv);
        net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
                                          rv);
        break;
      case STATE_GREET_READ:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
        rv = DoGreetRead();
        break;
      case STATE_GREET_READ_COMPLETE:
        rv = DoGreetReadComplete(rv);
        net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
                                          rv);
        break;
      case STATE_HANDSHAKE_WRITE:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
        rv = DoHandshakeWrite();
        break;
      case STATE_HANDSHAKE_WRITE_COMPLETE:
        rv = DoHandshakeWriteComplete(rv);
        net_log_.EndEventWithNetErrorCode(
            NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
        break;
      case STATE_HANDSHAKE_READ:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
        rv = DoHandshakeRead();
        break;
      case STATE_HANDSHAKE_READ_COMPLETE:
        rv = DoHandshakeReadComplete(rv);
        net_log_.EndEventWithNetErrorCode(
            NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
        break;
      default:
        NOTREACHED() << "bad state";
        rv = ERR_UNEXPECTED;
        break;
    }
  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
  return rv;
}
273 
// SOCKS5 greeting: version 0x05, one authentication method offered, and that
// method is 0x00 ("no authentication").
constexpr char kSOCKS5GreetWriteData[] = {0x05, 0x01, 0x00};
275 
DoGreetWrite()276 int SOCKS5ClientSocket::DoGreetWrite() {
277   // Since we only have 1 byte to send the hostname length in, if the
278   // URL has a hostname longer than 255 characters we can't send it.
279   if (0xFF < destination_.host().size()) {
280     net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
281     return ERR_SOCKS_CONNECTION_FAILED;
282   }
283 
284   if (buffer_.empty()) {
285     buffer_ =
286         std::string(kSOCKS5GreetWriteData, base::size(kSOCKS5GreetWriteData));
287     bytes_sent_ = 0;
288   }
289 
290   next_state_ = STATE_GREET_WRITE_COMPLETE;
291   size_t handshake_buf_len = buffer_.size() - bytes_sent_;
292   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
293   memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
294          handshake_buf_len);
295   return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
296                                   io_callback_, traffic_annotation_);
297 }
298 
DoGreetWriteComplete(int result)299 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
300   if (result < 0)
301     return result;
302 
303   bytes_sent_ += result;
304   if (bytes_sent_ == buffer_.size()) {
305     buffer_.clear();
306     bytes_received_ = 0;
307     next_state_ = STATE_GREET_READ;
308   } else {
309     next_state_ = STATE_GREET_WRITE;
310   }
311   return OK;
312 }
313 
DoGreetRead()314 int SOCKS5ClientSocket::DoGreetRead() {
315   next_state_ = STATE_GREET_READ_COMPLETE;
316   size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
317   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
318   return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
319                                  io_callback_);
320 }
321 
DoGreetReadComplete(int result)322 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
323   if (result < 0)
324     return result;
325 
326   if (result == 0) {
327     net_log_.AddEvent(
328         NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
329     return ERR_SOCKS_CONNECTION_FAILED;
330   }
331 
332   bytes_received_ += result;
333   buffer_.append(handshake_buf_->data(), result);
334   if (bytes_received_ < kGreetReadHeaderSize) {
335     next_state_ = STATE_GREET_READ;
336     return OK;
337   }
338 
339   // Got the greet data.
340   if (buffer_[0] != kSOCKS5Version) {
341     net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
342                                    "version", buffer_[0]);
343     return ERR_SOCKS_CONNECTION_FAILED;
344   }
345   if (buffer_[1] != 0x00) {
346     net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
347                                    "method", buffer_[1]);
348     return ERR_SOCKS_CONNECTION_FAILED;
349   }
350 
351   buffer_.clear();
352   next_state_ = STATE_HANDSHAKE_WRITE;
353   return OK;
354 }
355 
BuildHandshakeWriteBuffer(std::string * handshake) const356 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
357     const {
358   DCHECK(handshake->empty());
359 
360   handshake->push_back(kSOCKS5Version);
361   handshake->push_back(kTunnelCommand);  // Connect command
362   handshake->push_back(kNullByte);  // Reserved null
363 
364   handshake->push_back(kEndPointDomain);  // The type of the address.
365 
366   DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
367 
368   // First add the size of the hostname, followed by the hostname.
369   handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
370   handshake->append(destination_.host());
371 
372   uint16_t nw_port = base::HostToNet16(destination_.port());
373   handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
374   return OK;
375 }
376 
377 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()378 int SOCKS5ClientSocket::DoHandshakeWrite() {
379   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
380 
381   if (buffer_.empty()) {
382     int rv = BuildHandshakeWriteBuffer(&buffer_);
383     if (rv != OK)
384       return rv;
385     bytes_sent_ = 0;
386   }
387 
388   int handshake_buf_len = buffer_.size() - bytes_sent_;
389   DCHECK_LT(0, handshake_buf_len);
390   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
391   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
392          handshake_buf_len);
393   return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
394                                   io_callback_, traffic_annotation_);
395 }
396 
DoHandshakeWriteComplete(int result)397 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
398   if (result < 0)
399     return result;
400 
401   // We ignore the case when result is 0, since the underlying Write
402   // may return spurious writes while waiting on the socket.
403 
404   bytes_sent_ += result;
405   if (bytes_sent_ == buffer_.size()) {
406     next_state_ = STATE_HANDSHAKE_READ;
407     buffer_.clear();
408   } else if (bytes_sent_ < buffer_.size()) {
409     next_state_ = STATE_HANDSHAKE_WRITE;
410   } else {
411     NOTREACHED();
412   }
413 
414   return OK;
415 }
416 
DoHandshakeRead()417 int SOCKS5ClientSocket::DoHandshakeRead() {
418   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
419 
420   if (buffer_.empty()) {
421     bytes_received_ = 0;
422     read_header_size = kReadHeaderSize;
423   }
424 
425   int handshake_buf_len = read_header_size - bytes_received_;
426   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
427   return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
428                                  io_callback_);
429 }
430 
DoHandshakeReadComplete(int result)431 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
432   if (result < 0)
433     return result;
434 
435   // The underlying socket closed unexpectedly.
436   if (result == 0) {
437     net_log_.AddEvent(
438         NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
439     return ERR_SOCKS_CONNECTION_FAILED;
440   }
441 
442   buffer_.append(handshake_buf_->data(), result);
443   bytes_received_ += result;
444 
445   // When the first few bytes are read, check how many more are required
446   // and accordingly increase them
447   if (bytes_received_ == kReadHeaderSize) {
448     if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
449       net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
450                                      "version", buffer_[0]);
451       return ERR_SOCKS_CONNECTION_FAILED;
452     }
453     if (buffer_[1] != 0x00) {
454       net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
455                                      "error_code", buffer_[1]);
456       return ERR_SOCKS_CONNECTION_FAILED;
457     }
458 
459     // We check the type of IP/Domain the server returns and accordingly
460     // increase the size of the response. For domains, we need to read the
461     // size of the domain, so the initial request size is upto the domain
462     // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
463     // read, we substract 1 byte from the additional request size.
464     SocksEndPointAddressType address_type =
465         static_cast<SocksEndPointAddressType>(buffer_[3]);
466     if (address_type == kEndPointDomain) {
467       read_header_size += static_cast<uint8_t>(buffer_[4]);
468     } else if (address_type == kEndPointResolvedIPv4) {
469       read_header_size += sizeof(struct in_addr) - 1;
470     } else if (address_type == kEndPointResolvedIPv6) {
471       read_header_size += sizeof(struct in6_addr) - 1;
472     } else {
473       net_log_.AddEventWithIntParams(
474           NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
475           buffer_[3]);
476       return ERR_SOCKS_CONNECTION_FAILED;
477     }
478 
479     read_header_size += 2;  // for the port.
480     next_state_ = STATE_HANDSHAKE_READ;
481     return OK;
482   }
483 
484   // When the final bytes are read, setup handshake. We ignore the rest
485   // of the response since they represent the SOCKSv5 endpoint and have
486   // no use when doing a tunnel connection.
487   if (bytes_received_ == read_header_size) {
488     completed_handshake_ = true;
489     buffer_.clear();
490     next_state_ = STATE_NONE;
491     return OK;
492   }
493 
494   next_state_ = STATE_HANDSHAKE_READ;
495   return OK;
496 }
497 
GetPeerAddress(IPEndPoint * address) const498 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
499   return transport_socket_->GetPeerAddress(address);
500 }
501 
GetLocalAddress(IPEndPoint * address) const502 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
503   return transport_socket_->GetLocalAddress(address);
504 }
505 
506 }  // namespace net
507