1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/platform/named_platform_channel.h"
6 
7 #include <errno.h>
8 #include <sys/socket.h>
9 #include <sys/un.h>
10 #include <unistd.h>
11 
12 #include "base/files/file_util.h"
13 #include "base/files/scoped_file.h"
14 #include "base/logging.h"
15 #include "base/posix/eintr_wrapper.h"
16 #include "base/rand_util.h"
17 #include "base/strings/string_number_conversions.h"
18 
19 namespace mojo {
20 
21 namespace {
22 
// Produces a server name for callers that did not supply one: a random 64-bit
// value is converted to a string with base::NumberToString() and appended to
// |options.socket_dir| via AppendASCII(). The resulting value() is returned.
NamedPlatformChannel::ServerName GenerateRandomServerName(
    const NamedPlatformChannel::Options& options) {
  const auto random_component = base::NumberToString(base::RandUint64());
  const auto socket_path = options.socket_dir.AppendASCII(random_component);
  return socket_path.value();
}
29 
30 // This function fills in |unix_addr| with the appropriate data for the socket,
31 // and sets |unix_addr_len| to the length of the data therein.
32 // Returns true on success, or false on failure (typically because |server_name|
33 // violated the naming rules).
MakeUnixAddr(const NamedPlatformChannel::ServerName & server_name,struct sockaddr_un * unix_addr,size_t * unix_addr_len)34 bool MakeUnixAddr(const NamedPlatformChannel::ServerName& server_name,
35                   struct sockaddr_un* unix_addr,
36                   size_t* unix_addr_len) {
37   DCHECK(unix_addr);
38   DCHECK(unix_addr_len);
39   DCHECK(!server_name.empty());
40 
41   constexpr size_t kMaxSocketNameLength = 104;
42 
43   // We reject server_name.length() == kMaxSocketNameLength to make room for the
44   // NUL terminator at the end of the string.
45   if (server_name.length() >= kMaxSocketNameLength) {
46     LOG(ERROR) << "Socket name too long: " << server_name;
47     return false;
48   }
49 
50   // Create unix_addr structure.
51   memset(unix_addr, 0, sizeof(struct sockaddr_un));
52   unix_addr->sun_family = AF_UNIX;
53   strncpy(unix_addr->sun_path, server_name.c_str(), kMaxSocketNameLength);
54   *unix_addr_len =
55       offsetof(struct sockaddr_un, sun_path) + server_name.length();
56   return true;
57 }
58 
59 // This function creates a unix domain socket, and set it as non-blocking.
60 // If successful, this returns a PlatformHandle containing the socket.
61 // Otherwise, this returns an invalid PlatformHandle.
CreateUnixDomainSocket()62 PlatformHandle CreateUnixDomainSocket() {
63   // Create the unix domain socket.
64   PlatformHandle handle(base::ScopedFD(socket(AF_UNIX, SOCK_STREAM, 0)));
65   if (!handle.is_valid()) {
66     PLOG(ERROR) << "Failed to create AF_UNIX socket.";
67     return PlatformHandle();
68   }
69 
70   // Now set it as non-blocking.
71   if (!base::SetNonBlocking(handle.GetFD().get())) {
72     PLOG(ERROR) << "base::SetNonBlocking() failed " << handle.GetFD().get();
73     return PlatformHandle();
74   }
75   return handle;
76 }
77 
78 }  // namespace
79 
80 // static
CreateServerEndpoint(const Options & options,ServerName * server_name)81 PlatformChannelServerEndpoint NamedPlatformChannel::CreateServerEndpoint(
82     const Options& options,
83     ServerName* server_name) {
84   ServerName name = options.server_name;
85   if (name.empty())
86     name = GenerateRandomServerName(options);
87 
88   // Make sure the path we need exists.
89   base::FilePath socket_dir = base::FilePath(name).DirName();
90   if (!base::CreateDirectory(socket_dir)) {
91     LOG(ERROR) << "Couldn't create directory: " << socket_dir.value();
92     return PlatformChannelServerEndpoint();
93   }
94 
95   // Delete any old FS instances.
96   if (unlink(name.c_str()) < 0 && errno != ENOENT) {
97     PLOG(ERROR) << "unlink " << name;
98     return PlatformChannelServerEndpoint();
99   }
100 
101   struct sockaddr_un unix_addr;
102   size_t unix_addr_len;
103   if (!MakeUnixAddr(name, &unix_addr, &unix_addr_len))
104     return PlatformChannelServerEndpoint();
105 
106   PlatformHandle handle = CreateUnixDomainSocket();
107   if (!handle.is_valid())
108     return PlatformChannelServerEndpoint();
109 
110   // Bind the socket.
111   if (bind(handle.GetFD().get(), reinterpret_cast<const sockaddr*>(&unix_addr),
112            unix_addr_len) < 0) {
113     PLOG(ERROR) << "bind " << name;
114     return PlatformChannelServerEndpoint();
115   }
116 
117   // Start listening on the socket.
118   if (listen(handle.GetFD().get(), SOMAXCONN) < 0) {
119     PLOG(ERROR) << "listen " << name;
120     unlink(name.c_str());
121     return PlatformChannelServerEndpoint();
122   }
123 
124   *server_name = name;
125   return PlatformChannelServerEndpoint(std::move(handle));
126 }
127 
128 // static
CreateClientEndpoint(const ServerName & server_name)129 PlatformChannelEndpoint NamedPlatformChannel::CreateClientEndpoint(
130     const ServerName& server_name) {
131   DCHECK(!server_name.empty());
132 
133   struct sockaddr_un unix_addr;
134   size_t unix_addr_len;
135   if (!MakeUnixAddr(server_name, &unix_addr, &unix_addr_len))
136     return PlatformChannelEndpoint();
137 
138   PlatformHandle handle = CreateUnixDomainSocket();
139   if (!handle.is_valid())
140     return PlatformChannelEndpoint();
141 
142   if (HANDLE_EINTR(connect(handle.GetFD().get(),
143                            reinterpret_cast<sockaddr*>(&unix_addr),
144                            unix_addr_len)) < 0) {
145     PLOG(ERROR) << "connect " << server_name;
146     return PlatformChannelEndpoint();
147   }
148   return PlatformChannelEndpoint(std::move(handle));
149 }
150 
151 }  // namespace mojo
152