1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/unix_domain_server_socket_posix.h"
6 
7 #include <errno.h>
8 #include <sys/socket.h>
9 #include <sys/un.h>
10 #include <unistd.h>
11 #include <utility>
12 
13 #include "base/bind.h"
14 #include "base/logging.h"
15 #include "build/build_config.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/sockaddr_storage.h"
18 #include "net/socket/socket_posix.h"
19 #include "net/socket/unix_domain_client_socket_posix.h"
20 
21 namespace net {
22 
UnixDomainServerSocket(const AuthCallback & auth_callback,bool use_abstract_namespace)23 UnixDomainServerSocket::UnixDomainServerSocket(
24     const AuthCallback& auth_callback,
25     bool use_abstract_namespace)
26     : auth_callback_(auth_callback),
27       use_abstract_namespace_(use_abstract_namespace) {
28   DCHECK(!auth_callback_.is_null());
29 }
30 
31 UnixDomainServerSocket::~UnixDomainServerSocket() = default;
32 
33 // static
GetPeerCredentials(SocketDescriptor socket,Credentials * credentials)34 bool UnixDomainServerSocket::GetPeerCredentials(SocketDescriptor socket,
35                                                 Credentials* credentials) {
36 #if defined(OS_LINUX) || defined(OS_CHROMEOS) || defined(OS_ANDROID) || \
37     defined(OS_FUCHSIA)
38   struct ucred user_cred;
39   socklen_t len = sizeof(user_cred);
40   if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) < 0)
41     return false;
42   credentials->process_id = user_cred.pid;
43   credentials->user_id = user_cred.uid;
44   credentials->group_id = user_cred.gid;
45   return true;
46 #else
47   return getpeereid(
48       socket, &credentials->user_id, &credentials->group_id) == 0;
49 #endif
50 }
51 
Listen(const IPEndPoint & address,int backlog)52 int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) {
53   NOTIMPLEMENTED();
54   return ERR_NOT_IMPLEMENTED;
55 }
56 
ListenWithAddressAndPort(const std::string & address_string,uint16_t port,int backlog)57 int UnixDomainServerSocket::ListenWithAddressAndPort(
58     const std::string& address_string,
59     uint16_t port,
60     int backlog) {
61   NOTIMPLEMENTED();
62   return ERR_NOT_IMPLEMENTED;
63 }
64 
BindAndListen(const std::string & socket_path,int backlog)65 int UnixDomainServerSocket::BindAndListen(const std::string& socket_path,
66                                           int backlog) {
67   DCHECK(!listen_socket_);
68 
69   SockaddrStorage address;
70   if (!UnixDomainClientSocket::FillAddress(socket_path,
71                                            use_abstract_namespace_,
72                                            &address)) {
73     return ERR_ADDRESS_INVALID;
74   }
75 
76   std::unique_ptr<SocketPosix> socket(new SocketPosix);
77   int rv = socket->Open(AF_UNIX);
78   DCHECK_NE(ERR_IO_PENDING, rv);
79   if (rv != OK)
80     return rv;
81 
82   rv = socket->Bind(address);
83   DCHECK_NE(ERR_IO_PENDING, rv);
84   if (rv != OK) {
85     PLOG(ERROR)
86         << "Could not bind unix domain socket to " << socket_path
87         << (use_abstract_namespace_ ? " (with abstract namespace)" : "");
88     return rv;
89   }
90 
91   rv = socket->Listen(backlog);
92   DCHECK_NE(ERR_IO_PENDING, rv);
93   if (rv != OK)
94     return rv;
95 
96   listen_socket_.swap(socket);
97   return rv;
98 }
99 
GetLocalAddress(IPEndPoint * address) const100 int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
101   DCHECK(address);
102 
103   // Unix domain sockets have no valid associated addr/port;
104   // return address invalid.
105   return ERR_ADDRESS_INVALID;
106 }
107 
Accept(std::unique_ptr<StreamSocket> * socket,CompletionOnceCallback callback)108 int UnixDomainServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
109                                    CompletionOnceCallback callback) {
110   DCHECK(socket);
111   DCHECK(callback);
112   DCHECK(!callback_ && !out_socket_.stream && !out_socket_.descriptor);
113 
114   out_socket_ = {socket, nullptr};
115   int rv = DoAccept();
116   if (rv == ERR_IO_PENDING)
117     callback_ = std::move(callback);
118   else
119     CancelCallback();
120   return rv;
121 }
122 
AcceptSocketDescriptor(SocketDescriptor * socket,CompletionOnceCallback callback)123 int UnixDomainServerSocket::AcceptSocketDescriptor(
124     SocketDescriptor* socket,
125     CompletionOnceCallback callback) {
126   DCHECK(socket);
127   DCHECK(callback);
128   DCHECK(!callback_ && !out_socket_.stream && !out_socket_.descriptor);
129 
130   out_socket_ = {nullptr, socket};
131   int rv = DoAccept();
132   if (rv == ERR_IO_PENDING)
133     callback_ = std::move(callback);
134   else
135     CancelCallback();
136   return rv;
137 }
138 
DoAccept()139 int UnixDomainServerSocket::DoAccept() {
140   DCHECK(listen_socket_);
141   DCHECK(!accept_socket_);
142 
143   while (true) {
144     int rv = listen_socket_->Accept(
145         &accept_socket_,
146         base::BindOnce(&UnixDomainServerSocket::AcceptCompleted,
147                        base::Unretained(this)));
148     if (rv != OK)
149       return rv;
150     if (AuthenticateAndGetStreamSocket())
151       return OK;
152     // Accept another socket because authentication error should be transparent
153     // to the caller.
154   }
155 }
156 
AcceptCompleted(int rv)157 void UnixDomainServerSocket::AcceptCompleted(int rv) {
158   DCHECK(!callback_.is_null());
159 
160   if (rv != OK) {
161     RunCallback(rv);
162     return;
163   }
164 
165   if (AuthenticateAndGetStreamSocket()) {
166     RunCallback(OK);
167     return;
168   }
169 
170   // Accept another socket because authentication error should be transparent
171   // to the caller.
172   rv = DoAccept();
173   if (rv != ERR_IO_PENDING)
174     RunCallback(rv);
175 }
176 
AuthenticateAndGetStreamSocket()177 bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket() {
178   DCHECK(accept_socket_);
179 
180   Credentials credentials;
181   if (!GetPeerCredentials(accept_socket_->socket_fd(), &credentials) ||
182       !auth_callback_.Run(credentials)) {
183     accept_socket_.reset();
184     return false;
185   }
186 
187   SetSocketResult(std::move(accept_socket_));
188   return true;
189 }
190 
SetSocketResult(std::unique_ptr<SocketPosix> accepted_socket)191 void UnixDomainServerSocket::SetSocketResult(
192     std::unique_ptr<SocketPosix> accepted_socket) {
193   // Exactly one of the output pointers should be set.
194   DCHECK_NE(!!out_socket_.stream, !!out_socket_.descriptor);
195 
196   // Pass ownership of |accepted_socket|.
197   if (out_socket_.descriptor) {
198     *out_socket_.descriptor = accepted_socket->ReleaseConnectedSocket();
199     return;
200   }
201   *out_socket_.stream =
202       std::make_unique<UnixDomainClientSocket>(std::move(accepted_socket));
203 }
204 
RunCallback(int rv)205 void UnixDomainServerSocket::RunCallback(int rv) {
206   out_socket_ = SocketDestination();
207   std::move(callback_).Run(rv);
208 }
209 
CancelCallback()210 void UnixDomainServerSocket::CancelCallback() {
211   out_socket_ = SocketDestination();
212   callback_.Reset();
213 }
214 
215 }  // namespace net
216