1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks5_client_socket.h"
6
7 #if defined(OS_BSD)
8 #include <netinet/in.h>
9 #endif
10
11 #include <utility>
12
13 #include "base/bind.h"
14 #include "base/callback_helpers.h"
15 #include "base/compiler_specific.h"
16 #include "base/format_macros.h"
17 #include "base/stl_util.h"
18 #include "base/strings/string_util.h"
19 #include "base/sys_byteorder.h"
20 #include "base/trace_event/trace_event.h"
21 #include "net/base/io_buffer.h"
22 #include "net/log/net_log.h"
23 #include "net/log/net_log_event_type.h"
24 #include "net/traffic_annotation/network_traffic_annotation.h"
25
26 namespace net {
27
28 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
29 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
30 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
31 const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
32 const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
33 const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;
34
35 static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
36 static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
37
SOCKS5ClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkTrafficAnnotationTag & traffic_annotation)38 SOCKS5ClientSocket::SOCKS5ClientSocket(
39 std::unique_ptr<StreamSocket> transport_socket,
40 const HostPortPair& destination,
41 const NetworkTrafficAnnotationTag& traffic_annotation)
42 : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
43 base::Unretained(this))),
44 transport_socket_(std::move(transport_socket)),
45 next_state_(STATE_NONE),
46 completed_handshake_(false),
47 bytes_sent_(0),
48 bytes_received_(0),
49 read_header_size(kReadHeaderSize),
50 was_ever_used_(false),
51 destination_(destination),
52 net_log_(transport_socket_->NetLog()),
53 traffic_annotation_(traffic_annotation) {}
54
~SOCKS5ClientSocket()55 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
56 Disconnect();
57 }
58
Connect(CompletionOnceCallback callback)59 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
60 DCHECK(transport_socket_);
61 DCHECK_EQ(STATE_NONE, next_state_);
62 DCHECK(user_callback_.is_null());
63
64 // If already connected, then just return OK.
65 if (completed_handshake_)
66 return OK;
67
68 net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
69
70 next_state_ = STATE_GREET_WRITE;
71 buffer_.clear();
72
73 int rv = DoLoop(OK);
74 if (rv == ERR_IO_PENDING) {
75 user_callback_ = std::move(callback);
76 } else {
77 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
78 }
79 return rv;
80 }
81
Disconnect()82 void SOCKS5ClientSocket::Disconnect() {
83 completed_handshake_ = false;
84 transport_socket_->Disconnect();
85
86 // Reset other states to make sure they aren't mistakenly used later.
87 // These are the states initialized by Connect().
88 next_state_ = STATE_NONE;
89 user_callback_.Reset();
90 }
91
IsConnected() const92 bool SOCKS5ClientSocket::IsConnected() const {
93 return completed_handshake_ && transport_socket_->IsConnected();
94 }
95
IsConnectedAndIdle() const96 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
97 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
98 }
99
NetLog() const100 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
101 return net_log_;
102 }
103
WasEverUsed() const104 bool SOCKS5ClientSocket::WasEverUsed() const {
105 return was_ever_used_;
106 }
107
WasAlpnNegotiated() const108 bool SOCKS5ClientSocket::WasAlpnNegotiated() const {
109 if (transport_socket_)
110 return transport_socket_->WasAlpnNegotiated();
111 NOTREACHED();
112 return false;
113 }
114
GetNegotiatedProtocol() const115 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
116 if (transport_socket_)
117 return transport_socket_->GetNegotiatedProtocol();
118 NOTREACHED();
119 return kProtoUnknown;
120 }
121
GetSSLInfo(SSLInfo * ssl_info)122 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
123 if (transport_socket_)
124 return transport_socket_->GetSSLInfo(ssl_info);
125 NOTREACHED();
126 return false;
127 }
128
GetConnectionAttempts(ConnectionAttempts * out) const129 void SOCKS5ClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
130 out->clear();
131 }
132
GetTotalReceivedBytes() const133 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
134 return transport_socket_->GetTotalReceivedBytes();
135 }
136
ApplySocketTag(const SocketTag & tag)137 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
138 return transport_socket_->ApplySocketTag(tag);
139 }
140
141 // Read is called by the transport layer above to read. This can only be done
142 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)143 int SOCKS5ClientSocket::Read(IOBuffer* buf,
144 int buf_len,
145 CompletionOnceCallback callback) {
146 DCHECK(completed_handshake_);
147 DCHECK_EQ(STATE_NONE, next_state_);
148 DCHECK(user_callback_.is_null());
149 DCHECK(!callback.is_null());
150
151 int rv = transport_socket_->Read(
152 buf, buf_len,
153 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
154 base::Unretained(this), std::move(callback)));
155 if (rv > 0)
156 was_ever_used_ = true;
157 return rv;
158 }
159
160 // Write is called by the transport layer. This can only be done if the
161 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)162 int SOCKS5ClientSocket::Write(
163 IOBuffer* buf,
164 int buf_len,
165 CompletionOnceCallback callback,
166 const NetworkTrafficAnnotationTag& traffic_annotation) {
167 DCHECK(completed_handshake_);
168 DCHECK_EQ(STATE_NONE, next_state_);
169 DCHECK(user_callback_.is_null());
170 DCHECK(!callback.is_null());
171
172 int rv = transport_socket_->Write(
173 buf, buf_len,
174 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
175 base::Unretained(this), std::move(callback)),
176 traffic_annotation);
177 if (rv > 0)
178 was_ever_used_ = true;
179 return rv;
180 }
181
SetReceiveBufferSize(int32_t size)182 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
183 return transport_socket_->SetReceiveBufferSize(size);
184 }
185
SetSendBufferSize(int32_t size)186 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
187 return transport_socket_->SetSendBufferSize(size);
188 }
189
DoCallback(int result)190 void SOCKS5ClientSocket::DoCallback(int result) {
191 DCHECK_NE(ERR_IO_PENDING, result);
192 DCHECK(!user_callback_.is_null());
193
194 // Since Run() may result in Read being called,
195 // clear user_callback_ up front.
196 std::move(user_callback_).Run(result);
197 }
198
OnIOComplete(int result)199 void SOCKS5ClientSocket::OnIOComplete(int result) {
200 DCHECK_NE(STATE_NONE, next_state_);
201 int rv = DoLoop(result);
202 if (rv != ERR_IO_PENDING) {
203 net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
204 DoCallback(rv);
205 }
206 }
207
OnReadWriteComplete(CompletionOnceCallback callback,int result)208 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
209 int result) {
210 DCHECK_NE(ERR_IO_PENDING, result);
211 DCHECK(!callback.is_null());
212
213 if (result > 0)
214 was_ever_used_ = true;
215 std::move(callback).Run(result);
216 }
217
DoLoop(int last_io_result)218 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
219 DCHECK_NE(next_state_, STATE_NONE);
220 int rv = last_io_result;
221 do {
222 State state = next_state_;
223 next_state_ = STATE_NONE;
224 switch (state) {
225 case STATE_GREET_WRITE:
226 DCHECK_EQ(OK, rv);
227 net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
228 rv = DoGreetWrite();
229 break;
230 case STATE_GREET_WRITE_COMPLETE:
231 rv = DoGreetWriteComplete(rv);
232 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
233 rv);
234 break;
235 case STATE_GREET_READ:
236 DCHECK_EQ(OK, rv);
237 net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
238 rv = DoGreetRead();
239 break;
240 case STATE_GREET_READ_COMPLETE:
241 rv = DoGreetReadComplete(rv);
242 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
243 rv);
244 break;
245 case STATE_HANDSHAKE_WRITE:
246 DCHECK_EQ(OK, rv);
247 net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
248 rv = DoHandshakeWrite();
249 break;
250 case STATE_HANDSHAKE_WRITE_COMPLETE:
251 rv = DoHandshakeWriteComplete(rv);
252 net_log_.EndEventWithNetErrorCode(
253 NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
254 break;
255 case STATE_HANDSHAKE_READ:
256 DCHECK_EQ(OK, rv);
257 net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
258 rv = DoHandshakeRead();
259 break;
260 case STATE_HANDSHAKE_READ_COMPLETE:
261 rv = DoHandshakeReadComplete(rv);
262 net_log_.EndEventWithNetErrorCode(
263 NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
264 break;
265 default:
266 NOTREACHED() << "bad state";
267 rv = ERR_UNEXPECTED;
268 break;
269 }
270 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
271 return rv;
272 }
273
// Client greeting: version 5, one method offered, method 0x00 (no auth).
constexpr char kSOCKS5GreetWriteData[] = {'\x05', '\x01', '\x00'};
275
DoGreetWrite()276 int SOCKS5ClientSocket::DoGreetWrite() {
277 // Since we only have 1 byte to send the hostname length in, if the
278 // URL has a hostname longer than 255 characters we can't send it.
279 if (0xFF < destination_.host().size()) {
280 net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
281 return ERR_SOCKS_CONNECTION_FAILED;
282 }
283
284 if (buffer_.empty()) {
285 buffer_ =
286 std::string(kSOCKS5GreetWriteData, base::size(kSOCKS5GreetWriteData));
287 bytes_sent_ = 0;
288 }
289
290 next_state_ = STATE_GREET_WRITE_COMPLETE;
291 size_t handshake_buf_len = buffer_.size() - bytes_sent_;
292 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
293 memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
294 handshake_buf_len);
295 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
296 io_callback_, traffic_annotation_);
297 }
298
DoGreetWriteComplete(int result)299 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
300 if (result < 0)
301 return result;
302
303 bytes_sent_ += result;
304 if (bytes_sent_ == buffer_.size()) {
305 buffer_.clear();
306 bytes_received_ = 0;
307 next_state_ = STATE_GREET_READ;
308 } else {
309 next_state_ = STATE_GREET_WRITE;
310 }
311 return OK;
312 }
313
DoGreetRead()314 int SOCKS5ClientSocket::DoGreetRead() {
315 next_state_ = STATE_GREET_READ_COMPLETE;
316 size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
317 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
318 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
319 io_callback_);
320 }
321
DoGreetReadComplete(int result)322 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
323 if (result < 0)
324 return result;
325
326 if (result == 0) {
327 net_log_.AddEvent(
328 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
329 return ERR_SOCKS_CONNECTION_FAILED;
330 }
331
332 bytes_received_ += result;
333 buffer_.append(handshake_buf_->data(), result);
334 if (bytes_received_ < kGreetReadHeaderSize) {
335 next_state_ = STATE_GREET_READ;
336 return OK;
337 }
338
339 // Got the greet data.
340 if (buffer_[0] != kSOCKS5Version) {
341 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
342 "version", buffer_[0]);
343 return ERR_SOCKS_CONNECTION_FAILED;
344 }
345 if (buffer_[1] != 0x00) {
346 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
347 "method", buffer_[1]);
348 return ERR_SOCKS_CONNECTION_FAILED;
349 }
350
351 buffer_.clear();
352 next_state_ = STATE_HANDSHAKE_WRITE;
353 return OK;
354 }
355
BuildHandshakeWriteBuffer(std::string * handshake) const356 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
357 const {
358 DCHECK(handshake->empty());
359
360 handshake->push_back(kSOCKS5Version);
361 handshake->push_back(kTunnelCommand); // Connect command
362 handshake->push_back(kNullByte); // Reserved null
363
364 handshake->push_back(kEndPointDomain); // The type of the address.
365
366 DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
367
368 // First add the size of the hostname, followed by the hostname.
369 handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
370 handshake->append(destination_.host());
371
372 uint16_t nw_port = base::HostToNet16(destination_.port());
373 handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
374 return OK;
375 }
376
377 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()378 int SOCKS5ClientSocket::DoHandshakeWrite() {
379 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
380
381 if (buffer_.empty()) {
382 int rv = BuildHandshakeWriteBuffer(&buffer_);
383 if (rv != OK)
384 return rv;
385 bytes_sent_ = 0;
386 }
387
388 int handshake_buf_len = buffer_.size() - bytes_sent_;
389 DCHECK_LT(0, handshake_buf_len);
390 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
391 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
392 handshake_buf_len);
393 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
394 io_callback_, traffic_annotation_);
395 }
396
DoHandshakeWriteComplete(int result)397 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
398 if (result < 0)
399 return result;
400
401 // We ignore the case when result is 0, since the underlying Write
402 // may return spurious writes while waiting on the socket.
403
404 bytes_sent_ += result;
405 if (bytes_sent_ == buffer_.size()) {
406 next_state_ = STATE_HANDSHAKE_READ;
407 buffer_.clear();
408 } else if (bytes_sent_ < buffer_.size()) {
409 next_state_ = STATE_HANDSHAKE_WRITE;
410 } else {
411 NOTREACHED();
412 }
413
414 return OK;
415 }
416
DoHandshakeRead()417 int SOCKS5ClientSocket::DoHandshakeRead() {
418 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
419
420 if (buffer_.empty()) {
421 bytes_received_ = 0;
422 read_header_size = kReadHeaderSize;
423 }
424
425 int handshake_buf_len = read_header_size - bytes_received_;
426 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
427 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
428 io_callback_);
429 }
430
DoHandshakeReadComplete(int result)431 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
432 if (result < 0)
433 return result;
434
435 // The underlying socket closed unexpectedly.
436 if (result == 0) {
437 net_log_.AddEvent(
438 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
439 return ERR_SOCKS_CONNECTION_FAILED;
440 }
441
442 buffer_.append(handshake_buf_->data(), result);
443 bytes_received_ += result;
444
445 // When the first few bytes are read, check how many more are required
446 // and accordingly increase them
447 if (bytes_received_ == kReadHeaderSize) {
448 if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
449 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
450 "version", buffer_[0]);
451 return ERR_SOCKS_CONNECTION_FAILED;
452 }
453 if (buffer_[1] != 0x00) {
454 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
455 "error_code", buffer_[1]);
456 return ERR_SOCKS_CONNECTION_FAILED;
457 }
458
459 // We check the type of IP/Domain the server returns and accordingly
460 // increase the size of the response. For domains, we need to read the
461 // size of the domain, so the initial request size is upto the domain
462 // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
463 // read, we substract 1 byte from the additional request size.
464 SocksEndPointAddressType address_type =
465 static_cast<SocksEndPointAddressType>(buffer_[3]);
466 if (address_type == kEndPointDomain) {
467 read_header_size += static_cast<uint8_t>(buffer_[4]);
468 } else if (address_type == kEndPointResolvedIPv4) {
469 read_header_size += sizeof(struct in_addr) - 1;
470 } else if (address_type == kEndPointResolvedIPv6) {
471 read_header_size += sizeof(struct in6_addr) - 1;
472 } else {
473 net_log_.AddEventWithIntParams(
474 NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
475 buffer_[3]);
476 return ERR_SOCKS_CONNECTION_FAILED;
477 }
478
479 read_header_size += 2; // for the port.
480 next_state_ = STATE_HANDSHAKE_READ;
481 return OK;
482 }
483
484 // When the final bytes are read, setup handshake. We ignore the rest
485 // of the response since they represent the SOCKSv5 endpoint and have
486 // no use when doing a tunnel connection.
487 if (bytes_received_ == read_header_size) {
488 completed_handshake_ = true;
489 buffer_.clear();
490 next_state_ = STATE_NONE;
491 return OK;
492 }
493
494 next_state_ = STATE_HANDSHAKE_READ;
495 return OK;
496 }
497
GetPeerAddress(IPEndPoint * address) const498 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
499 return transport_socket_->GetPeerAddress(address);
500 }
501
GetLocalAddress(IPEndPoint * address) const502 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
503 return transport_socket_->GetLocalAddress(address);
504 }
505
506 } // namespace net
507