1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "platform/impl/stream_socket_posix.h"
6 
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <poll.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
14 
15 namespace openscreen {
16 
namespace {
constexpr int kDefaultMaxBacklogSize = 64;

// Polls |fd| with a zero timeout, so that it never blocks, and reports whether
// it is readable. For a listening socket, readable means a connection is
// waiting to be accepted. Error and hang-up conditions also count as
// "pending" so that the subsequent accept() surfaces them, matching what
// select() reports as readable.
//
// poll() is used rather than select(): FD_SET() on a descriptor >= FD_SETSIZE
// is undefined behavior (typically a buffer overflow), and a busy process can
// easily hold more than FD_SETSIZE descriptors.
bool IsConnectionPending(int fd) {
  struct pollfd poll_fd = {};
  poll_fd.fd = fd;
  poll_fd.events = POLLIN;
  if (poll(&poll_fd, 1, /*timeout=*/0) <= 0) {
    return false;
  }
  return (poll_fd.revents & (POLLIN | POLLERR | POLLHUP)) != 0;
}
}  // namespace
32 
StreamSocketPosix(IPAddress::Version version)33 StreamSocketPosix::StreamSocketPosix(IPAddress::Version version)
34     : version_(version) {}
35 
StreamSocketPosix(const IPEndpoint & local_endpoint)36 StreamSocketPosix::StreamSocketPosix(const IPEndpoint& local_endpoint)
37     : version_(local_endpoint.address.version()),
38       local_address_(local_endpoint) {}
39 
StreamSocketPosix(SocketAddressPosix local_address,IPEndpoint remote_address,int file_descriptor)40 StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address,
41                                      IPEndpoint remote_address,
42                                      int file_descriptor)
43     : handle_(file_descriptor),
44       version_(local_address.version()),
45       local_address_(local_address),
46       remote_address_(remote_address),
47       state_(SocketState::kConnected) {
48   EnsureInitialized();
49 }
50 
~StreamSocketPosix()51 StreamSocketPosix::~StreamSocketPosix() {
52   if (state_ == SocketState::kConnected) {
53     Close();
54   }
55 }
56 
GetWeakPtr() const57 WeakPtr<StreamSocketPosix> StreamSocketPosix::GetWeakPtr() const {
58   return weak_factory_.GetWeakPtr();
59 }
60 
Accept()61 ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() {
62   if (!EnsureInitialized()) {
63     return ReportSocketClosedError();
64   }
65 
66   if (!is_bound_) {
67     return CloseOnError(Error::Code::kSocketInvalidState);
68   }
69 
70   // Check if any connection is pending, and return a special error code if not.
71   if (!IsConnectionPending(handle_.fd)) {
72     return Error::Code::kAgain;
73   }
74 
75   // We copy our address to new_remote_address since it should be in the same
76   // family. The accept call will overwrite it.
77   SocketAddressPosix new_remote_address = local_address_.value();
78   socklen_t remote_address_size = new_remote_address.size();
79   const int new_file_descriptor =
80       accept(handle_.fd, new_remote_address.address(), &remote_address_size);
81   if (new_file_descriptor == kUnsetHandleFd) {
82     return CloseOnError(
83         Error(Error::Code::kSocketAcceptFailure, strerror(errno)));
84   }
85   new_remote_address.RecomputeEndpoint();
86 
87   return ErrorOr<std::unique_ptr<StreamSocket>>(
88       std::make_unique<StreamSocketPosix>(local_address_.value(),
89                                           new_remote_address.endpoint(),
90                                           new_file_descriptor));
91 }
92 
Bind()93 Error StreamSocketPosix::Bind() {
94   if (!local_address_.has_value()) {
95     return CloseOnError(Error::Code::kSocketInvalidState);
96   }
97 
98   if (!EnsureInitialized()) {
99     return ReportSocketClosedError();
100   }
101 
102   if (is_bound_) {
103     return CloseOnError(Error::Code::kSocketInvalidState);
104   }
105 
106   if (bind(handle_.fd, local_address_.value().address(),
107            local_address_.value().size()) != 0) {
108     return CloseOnError(
109         Error(Error::Code::kSocketBindFailure, strerror(errno)));
110   }
111 
112   is_bound_ = true;
113   return Error::None();
114 }
115 
Close()116 Error StreamSocketPosix::Close() {
117   if (!EnsureInitialized()) {
118     return ReportSocketClosedError();
119   }
120 
121   if (state_ == SocketState::kClosed) {
122     last_error_code_ = Error::Code::kSocketInvalidState;
123     return Error::Code::kSocketInvalidState;
124   }
125 
126   const int file_descriptor_to_close = handle_.fd;
127   if (close(file_descriptor_to_close) != 0) {
128     last_error_code_ = Error::Code::kSocketInvalidState;
129     return Error::Code::kSocketInvalidState;
130   }
131   handle_.fd = kUnsetHandleFd;
132 
133   return Error::None();
134 }
135 
Connect(const IPEndpoint & remote_endpoint)136 Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) {
137   if (!EnsureInitialized()) {
138     return ReportSocketClosedError();
139   }
140 
141   if (!is_initialized_ && !is_bound_) {
142     return CloseOnError(Error::Code::kSocketInvalidState);
143   }
144 
145   SocketAddressPosix address(remote_endpoint);
146   int ret = connect(handle_.fd, address.address(), address.size());
147   if (ret != 0 && errno != EINPROGRESS) {
148     return CloseOnError(
149         Error(Error::Code::kSocketConnectFailure, strerror(errno)));
150   }
151 
152   if (!is_bound_) {
153     if (local_address_.has_value()) {
154       return CloseOnError(Error::Code::kSocketInvalidState);
155     }
156 
157     struct sockaddr_in6 address;
158     socklen_t size = sizeof(address);
159     if (getsockname(handle_.fd, reinterpret_cast<struct sockaddr*>(&address),
160                     &size) != 0) {
161       return CloseOnError(Error::Code::kSocketConnectFailure);
162     }
163 
164     local_address_.emplace(reinterpret_cast<struct sockaddr&>(address));
165     is_bound_ = true;
166   }
167 
168   remote_address_ = remote_endpoint;
169   state_ = SocketState::kConnected;
170   return Error::None();
171 }
172 
Listen()173 Error StreamSocketPosix::Listen() {
174   return Listen(kDefaultMaxBacklogSize);
175 }
176 
Listen(int max_backlog_size)177 Error StreamSocketPosix::Listen(int max_backlog_size) {
178   if (!EnsureInitialized()) {
179     return ReportSocketClosedError();
180   }
181 
182   if (listen(handle_.fd, max_backlog_size) != 0) {
183     return CloseOnError(
184         Error(Error::Code::kSocketListenFailure, strerror(errno)));
185   }
186 
187   return Error::None();
188 }
189 
remote_address() const190 absl::optional<IPEndpoint> StreamSocketPosix::remote_address() const {
191   if ((state_ != SocketState::kConnected) || !remote_address_) {
192     return absl::nullopt;
193   }
194   return remote_address_.value();
195 }
196 
local_address() const197 absl::optional<IPEndpoint> StreamSocketPosix::local_address() const {
198   if (!local_address_.has_value()) {
199     return absl::nullopt;
200   }
201   return local_address_.value().endpoint();
202 }
203 
state() const204 SocketState StreamSocketPosix::state() const {
205   return state_;
206 }
207 
version() const208 IPAddress::Version StreamSocketPosix::version() const {
209   return version_;
210 }
211 
EnsureInitialized()212 bool StreamSocketPosix::EnsureInitialized() {
213   if (!is_initialized_ && (last_error_code_ == Error::Code::kNone)) {
214     return Initialize() == Error::None();
215   }
216 
217   return handle_.fd != kUnsetHandleFd && is_initialized_;
218 }
219 
Initialize()220 Error StreamSocketPosix::Initialize() {
221   if (is_initialized_) {
222     return Error::Code::kItemAlreadyExists;
223   }
224 
225   int fd = handle_.fd;
226   if (fd == kUnsetHandleFd) {
227     int domain;
228     switch (version_) {
229       case IPAddress::Version::kV4:
230         domain = AF_INET;
231         break;
232       case IPAddress::Version::kV6:
233         domain = AF_INET6;
234         break;
235     }
236 
237     fd = socket(domain, SOCK_STREAM, 0);
238     if (fd == kUnsetHandleFd) {
239       last_error_code_ = Error::Code::kSocketInvalidState;
240       return Error::Code::kSocketInvalidState;
241     }
242   }
243 
244   const int current_flags = fcntl(fd, F_GETFL, 0);
245   if (fcntl(fd, F_SETFL, current_flags | O_NONBLOCK) == -1) {
246     close(fd);
247     last_error_code_ = Error::Code::kSocketInvalidState;
248     return Error::Code::kSocketInvalidState;
249   }
250 
251   handle_.fd = fd;
252   is_initialized_ = true;
253   // last_error_code_ should still be Error::None().
254   return Error::None();
255 }
256 
CloseOnError(Error error)257 Error StreamSocketPosix::CloseOnError(Error error) {
258   last_error_code_ = error.code();
259   Close();
260   state_ = SocketState::kClosed;
261   return error;
262 }
263 
264 // If is_open is false, the socket has either not been initialized
265 // or has been closed, either on purpose or due to error.
ReportSocketClosedError()266 Error StreamSocketPosix::ReportSocketClosedError() {
267   last_error_code_ = Error::Code::kSocketClosedFailure;
268   return Error::Code::kSocketClosedFailure;
269 }
270 }  // namespace openscreen
271