1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/unix_domain_server_socket_posix.h"
6
7 #include <errno.h>
8 #include <sys/socket.h>
9 #include <sys/un.h>
10 #include <unistd.h>
11 #include <utility>
12
13 #include "base/bind.h"
14 #include "base/logging.h"
15 #include "build/build_config.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/sockaddr_storage.h"
18 #include "net/socket/socket_posix.h"
19 #include "net/socket/unix_domain_client_socket_posix.h"
20
21 namespace net {
22
UnixDomainServerSocket(const AuthCallback & auth_callback,bool use_abstract_namespace)23 UnixDomainServerSocket::UnixDomainServerSocket(
24 const AuthCallback& auth_callback,
25 bool use_abstract_namespace)
26 : auth_callback_(auth_callback),
27 use_abstract_namespace_(use_abstract_namespace) {
28 DCHECK(!auth_callback_.is_null());
29 }
30
31 UnixDomainServerSocket::~UnixDomainServerSocket() = default;
32
33 // static
GetPeerCredentials(SocketDescriptor socket,Credentials * credentials)34 bool UnixDomainServerSocket::GetPeerCredentials(SocketDescriptor socket,
35 Credentials* credentials) {
36 #if defined(OS_LINUX) || defined(OS_CHROMEOS) || defined(OS_ANDROID) || \
37 defined(OS_FUCHSIA)
38 struct ucred user_cred;
39 socklen_t len = sizeof(user_cred);
40 if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) < 0)
41 return false;
42 credentials->process_id = user_cred.pid;
43 credentials->user_id = user_cred.uid;
44 credentials->group_id = user_cred.gid;
45 return true;
46 #else
47 return getpeereid(
48 socket, &credentials->user_id, &credentials->group_id) == 0;
49 #endif
50 }
51
Listen(const IPEndPoint & address,int backlog)52 int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) {
53 NOTIMPLEMENTED();
54 return ERR_NOT_IMPLEMENTED;
55 }
56
ListenWithAddressAndPort(const std::string & address_string,uint16_t port,int backlog)57 int UnixDomainServerSocket::ListenWithAddressAndPort(
58 const std::string& address_string,
59 uint16_t port,
60 int backlog) {
61 NOTIMPLEMENTED();
62 return ERR_NOT_IMPLEMENTED;
63 }
64
BindAndListen(const std::string & socket_path,int backlog)65 int UnixDomainServerSocket::BindAndListen(const std::string& socket_path,
66 int backlog) {
67 DCHECK(!listen_socket_);
68
69 SockaddrStorage address;
70 if (!UnixDomainClientSocket::FillAddress(socket_path,
71 use_abstract_namespace_,
72 &address)) {
73 return ERR_ADDRESS_INVALID;
74 }
75
76 std::unique_ptr<SocketPosix> socket(new SocketPosix);
77 int rv = socket->Open(AF_UNIX);
78 DCHECK_NE(ERR_IO_PENDING, rv);
79 if (rv != OK)
80 return rv;
81
82 rv = socket->Bind(address);
83 DCHECK_NE(ERR_IO_PENDING, rv);
84 if (rv != OK) {
85 PLOG(ERROR)
86 << "Could not bind unix domain socket to " << socket_path
87 << (use_abstract_namespace_ ? " (with abstract namespace)" : "");
88 return rv;
89 }
90
91 rv = socket->Listen(backlog);
92 DCHECK_NE(ERR_IO_PENDING, rv);
93 if (rv != OK)
94 return rv;
95
96 listen_socket_.swap(socket);
97 return rv;
98 }
99
GetLocalAddress(IPEndPoint * address) const100 int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
101 DCHECK(address);
102
103 // Unix domain sockets have no valid associated addr/port;
104 // return address invalid.
105 return ERR_ADDRESS_INVALID;
106 }
107
Accept(std::unique_ptr<StreamSocket> * socket,CompletionOnceCallback callback)108 int UnixDomainServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
109 CompletionOnceCallback callback) {
110 DCHECK(socket);
111 DCHECK(callback);
112 DCHECK(!callback_ && !out_socket_.stream && !out_socket_.descriptor);
113
114 out_socket_ = {socket, nullptr};
115 int rv = DoAccept();
116 if (rv == ERR_IO_PENDING)
117 callback_ = std::move(callback);
118 else
119 CancelCallback();
120 return rv;
121 }
122
AcceptSocketDescriptor(SocketDescriptor * socket,CompletionOnceCallback callback)123 int UnixDomainServerSocket::AcceptSocketDescriptor(
124 SocketDescriptor* socket,
125 CompletionOnceCallback callback) {
126 DCHECK(socket);
127 DCHECK(callback);
128 DCHECK(!callback_ && !out_socket_.stream && !out_socket_.descriptor);
129
130 out_socket_ = {nullptr, socket};
131 int rv = DoAccept();
132 if (rv == ERR_IO_PENDING)
133 callback_ = std::move(callback);
134 else
135 CancelCallback();
136 return rv;
137 }
138
DoAccept()139 int UnixDomainServerSocket::DoAccept() {
140 DCHECK(listen_socket_);
141 DCHECK(!accept_socket_);
142
143 while (true) {
144 int rv = listen_socket_->Accept(
145 &accept_socket_,
146 base::BindOnce(&UnixDomainServerSocket::AcceptCompleted,
147 base::Unretained(this)));
148 if (rv != OK)
149 return rv;
150 if (AuthenticateAndGetStreamSocket())
151 return OK;
152 // Accept another socket because authentication error should be transparent
153 // to the caller.
154 }
155 }
156
AcceptCompleted(int rv)157 void UnixDomainServerSocket::AcceptCompleted(int rv) {
158 DCHECK(!callback_.is_null());
159
160 if (rv != OK) {
161 RunCallback(rv);
162 return;
163 }
164
165 if (AuthenticateAndGetStreamSocket()) {
166 RunCallback(OK);
167 return;
168 }
169
170 // Accept another socket because authentication error should be transparent
171 // to the caller.
172 rv = DoAccept();
173 if (rv != ERR_IO_PENDING)
174 RunCallback(rv);
175 }
176
AuthenticateAndGetStreamSocket()177 bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket() {
178 DCHECK(accept_socket_);
179
180 Credentials credentials;
181 if (!GetPeerCredentials(accept_socket_->socket_fd(), &credentials) ||
182 !auth_callback_.Run(credentials)) {
183 accept_socket_.reset();
184 return false;
185 }
186
187 SetSocketResult(std::move(accept_socket_));
188 return true;
189 }
190
SetSocketResult(std::unique_ptr<SocketPosix> accepted_socket)191 void UnixDomainServerSocket::SetSocketResult(
192 std::unique_ptr<SocketPosix> accepted_socket) {
193 // Exactly one of the output pointers should be set.
194 DCHECK_NE(!!out_socket_.stream, !!out_socket_.descriptor);
195
196 // Pass ownership of |accepted_socket|.
197 if (out_socket_.descriptor) {
198 *out_socket_.descriptor = accepted_socket->ReleaseConnectedSocket();
199 return;
200 }
201 *out_socket_.stream =
202 std::make_unique<UnixDomainClientSocket>(std::move(accepted_socket));
203 }
204
RunCallback(int rv)205 void UnixDomainServerSocket::RunCallback(int rv) {
206 out_socket_ = SocketDestination();
207 std::move(callback_).Run(rv);
208 }
209
CancelCallback()210 void UnixDomainServerSocket::CancelCallback() {
211 out_socket_ = SocketDestination();
212 callback_.Reset();
213 }
214
215 } // namespace net
216