1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "platform/impl/stream_socket_posix.h"
6
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <poll.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
14
15 namespace openscreen {
16
namespace {
// Backlog passed to listen() by Listen() when the caller does not supply one.
constexpr int kDefaultMaxBacklogSize = 64;

// Returns true if |fd| can be read without blocking. For a listening socket
// this means accept() has a connection (or a pending error) to return.
// poll() with a zero timeout is used so the call never blocks. It replaces
// select() because FD_SET() on a descriptor >= FD_SETSIZE writes past the end
// of the fd_set.
bool IsConnectionPending(int fd) {
  struct pollfd poll_fd = {};
  poll_fd.fd = fd;
  poll_fd.events = POLLIN;
  if (poll(&poll_fd, 1, /*timeout=*/0) <= 0) {
    return false;
  }
  // select() also reports a descriptor as readable when it has a pending error
  // or hang-up. Preserve that so accept() gets the chance to surface it.
  // POLLNVAL (bad fd) is excluded, matching select() failing with EBADF.
  return (poll_fd.revents & (POLLIN | POLLERR | POLLHUP)) != 0;
}
}  // namespace
32
StreamSocketPosix(IPAddress::Version version)33 StreamSocketPosix::StreamSocketPosix(IPAddress::Version version)
34 : version_(version) {}
35
StreamSocketPosix(const IPEndpoint & local_endpoint)36 StreamSocketPosix::StreamSocketPosix(const IPEndpoint& local_endpoint)
37 : version_(local_endpoint.address.version()),
38 local_address_(local_endpoint) {}
39
StreamSocketPosix(SocketAddressPosix local_address,IPEndpoint remote_address,int file_descriptor)40 StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address,
41 IPEndpoint remote_address,
42 int file_descriptor)
43 : handle_(file_descriptor),
44 version_(local_address.version()),
45 local_address_(local_address),
46 remote_address_(remote_address),
47 state_(SocketState::kConnected) {
48 EnsureInitialized();
49 }
50
~StreamSocketPosix()51 StreamSocketPosix::~StreamSocketPosix() {
52 if (state_ == SocketState::kConnected) {
53 Close();
54 }
55 }
56
GetWeakPtr() const57 WeakPtr<StreamSocketPosix> StreamSocketPosix::GetWeakPtr() const {
58 return weak_factory_.GetWeakPtr();
59 }
60
Accept()61 ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() {
62 if (!EnsureInitialized()) {
63 return ReportSocketClosedError();
64 }
65
66 if (!is_bound_) {
67 return CloseOnError(Error::Code::kSocketInvalidState);
68 }
69
70 // Check if any connection is pending, and return a special error code if not.
71 if (!IsConnectionPending(handle_.fd)) {
72 return Error::Code::kAgain;
73 }
74
75 // We copy our address to new_remote_address since it should be in the same
76 // family. The accept call will overwrite it.
77 SocketAddressPosix new_remote_address = local_address_.value();
78 socklen_t remote_address_size = new_remote_address.size();
79 const int new_file_descriptor =
80 accept(handle_.fd, new_remote_address.address(), &remote_address_size);
81 if (new_file_descriptor == kUnsetHandleFd) {
82 return CloseOnError(
83 Error(Error::Code::kSocketAcceptFailure, strerror(errno)));
84 }
85 new_remote_address.RecomputeEndpoint();
86
87 return ErrorOr<std::unique_ptr<StreamSocket>>(
88 std::make_unique<StreamSocketPosix>(local_address_.value(),
89 new_remote_address.endpoint(),
90 new_file_descriptor));
91 }
92
Bind()93 Error StreamSocketPosix::Bind() {
94 if (!local_address_.has_value()) {
95 return CloseOnError(Error::Code::kSocketInvalidState);
96 }
97
98 if (!EnsureInitialized()) {
99 return ReportSocketClosedError();
100 }
101
102 if (is_bound_) {
103 return CloseOnError(Error::Code::kSocketInvalidState);
104 }
105
106 if (bind(handle_.fd, local_address_.value().address(),
107 local_address_.value().size()) != 0) {
108 return CloseOnError(
109 Error(Error::Code::kSocketBindFailure, strerror(errno)));
110 }
111
112 is_bound_ = true;
113 return Error::None();
114 }
115
Close()116 Error StreamSocketPosix::Close() {
117 if (!EnsureInitialized()) {
118 return ReportSocketClosedError();
119 }
120
121 if (state_ == SocketState::kClosed) {
122 last_error_code_ = Error::Code::kSocketInvalidState;
123 return Error::Code::kSocketInvalidState;
124 }
125
126 const int file_descriptor_to_close = handle_.fd;
127 if (close(file_descriptor_to_close) != 0) {
128 last_error_code_ = Error::Code::kSocketInvalidState;
129 return Error::Code::kSocketInvalidState;
130 }
131 handle_.fd = kUnsetHandleFd;
132
133 return Error::None();
134 }
135
Connect(const IPEndpoint & remote_endpoint)136 Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) {
137 if (!EnsureInitialized()) {
138 return ReportSocketClosedError();
139 }
140
141 if (!is_initialized_ && !is_bound_) {
142 return CloseOnError(Error::Code::kSocketInvalidState);
143 }
144
145 SocketAddressPosix address(remote_endpoint);
146 int ret = connect(handle_.fd, address.address(), address.size());
147 if (ret != 0 && errno != EINPROGRESS) {
148 return CloseOnError(
149 Error(Error::Code::kSocketConnectFailure, strerror(errno)));
150 }
151
152 if (!is_bound_) {
153 if (local_address_.has_value()) {
154 return CloseOnError(Error::Code::kSocketInvalidState);
155 }
156
157 struct sockaddr_in6 address;
158 socklen_t size = sizeof(address);
159 if (getsockname(handle_.fd, reinterpret_cast<struct sockaddr*>(&address),
160 &size) != 0) {
161 return CloseOnError(Error::Code::kSocketConnectFailure);
162 }
163
164 local_address_.emplace(reinterpret_cast<struct sockaddr&>(address));
165 is_bound_ = true;
166 }
167
168 remote_address_ = remote_endpoint;
169 state_ = SocketState::kConnected;
170 return Error::None();
171 }
172
Listen()173 Error StreamSocketPosix::Listen() {
174 return Listen(kDefaultMaxBacklogSize);
175 }
176
Listen(int max_backlog_size)177 Error StreamSocketPosix::Listen(int max_backlog_size) {
178 if (!EnsureInitialized()) {
179 return ReportSocketClosedError();
180 }
181
182 if (listen(handle_.fd, max_backlog_size) != 0) {
183 return CloseOnError(
184 Error(Error::Code::kSocketListenFailure, strerror(errno)));
185 }
186
187 return Error::None();
188 }
189
remote_address() const190 absl::optional<IPEndpoint> StreamSocketPosix::remote_address() const {
191 if ((state_ != SocketState::kConnected) || !remote_address_) {
192 return absl::nullopt;
193 }
194 return remote_address_.value();
195 }
196
local_address() const197 absl::optional<IPEndpoint> StreamSocketPosix::local_address() const {
198 if (!local_address_.has_value()) {
199 return absl::nullopt;
200 }
201 return local_address_.value().endpoint();
202 }
203
state() const204 SocketState StreamSocketPosix::state() const {
205 return state_;
206 }
207
version() const208 IPAddress::Version StreamSocketPosix::version() const {
209 return version_;
210 }
211
EnsureInitialized()212 bool StreamSocketPosix::EnsureInitialized() {
213 if (!is_initialized_ && (last_error_code_ == Error::Code::kNone)) {
214 return Initialize() == Error::None();
215 }
216
217 return handle_.fd != kUnsetHandleFd && is_initialized_;
218 }
219
Initialize()220 Error StreamSocketPosix::Initialize() {
221 if (is_initialized_) {
222 return Error::Code::kItemAlreadyExists;
223 }
224
225 int fd = handle_.fd;
226 if (fd == kUnsetHandleFd) {
227 int domain;
228 switch (version_) {
229 case IPAddress::Version::kV4:
230 domain = AF_INET;
231 break;
232 case IPAddress::Version::kV6:
233 domain = AF_INET6;
234 break;
235 }
236
237 fd = socket(domain, SOCK_STREAM, 0);
238 if (fd == kUnsetHandleFd) {
239 last_error_code_ = Error::Code::kSocketInvalidState;
240 return Error::Code::kSocketInvalidState;
241 }
242 }
243
244 const int current_flags = fcntl(fd, F_GETFL, 0);
245 if (fcntl(fd, F_SETFL, current_flags | O_NONBLOCK) == -1) {
246 close(fd);
247 last_error_code_ = Error::Code::kSocketInvalidState;
248 return Error::Code::kSocketInvalidState;
249 }
250
251 handle_.fd = fd;
252 is_initialized_ = true;
253 // last_error_code_ should still be Error::None().
254 return Error::None();
255 }
256
CloseOnError(Error error)257 Error StreamSocketPosix::CloseOnError(Error error) {
258 last_error_code_ = error.code();
259 Close();
260 state_ = SocketState::kClosed;
261 return error;
262 }
263
264 // If is_open is false, the socket has either not been initialized
265 // or has been closed, either on purpose or due to error.
ReportSocketClosedError()266 Error StreamSocketPosix::ReportSocketClosedError() {
267 last_error_code_ = Error::Code::kSocketClosedFailure;
268 return Error::Code::kSocketClosedFailure;
269 }
270 } // namespace openscreen
271