1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/core/broker.h"
6 
7 #include <fcntl.h>
8 #include <unistd.h>
9 
#include <limits>
#include <utility>
#include <vector>
12 
13 #include "base/logging.h"
14 #include "base/memory/platform_shared_memory_region.h"
15 #include "build/build_config.h"
16 #include "mojo/core/broker_messages.h"
17 #include "mojo/core/channel.h"
18 #include "mojo/core/platform_handle_utils.h"
19 #include "mojo/public/cpp/platform/socket_utils_posix.h"
20 
21 namespace mojo {
22 namespace core {
23 
24 namespace {
25 
WaitForBrokerMessage(int socket_fd,BrokerMessageType expected_type,size_t expected_num_handles,size_t expected_data_size,std::vector<PlatformHandle> * incoming_handles)26 Channel::MessagePtr WaitForBrokerMessage(
27     int socket_fd,
28     BrokerMessageType expected_type,
29     size_t expected_num_handles,
30     size_t expected_data_size,
31     std::vector<PlatformHandle>* incoming_handles) {
32   Channel::MessagePtr message(new Channel::Message(
33       sizeof(BrokerMessageHeader) + expected_data_size, expected_num_handles));
34   std::vector<base::ScopedFD> incoming_fds;
35   ssize_t read_result =
36       SocketRecvmsg(socket_fd, const_cast<void*>(message->data()),
37                     message->data_num_bytes(), &incoming_fds, true /* block */);
38   bool error = false;
39   if (read_result < 0) {
40     PLOG(ERROR) << "Recvmsg error";
41     error = true;
42   } else if (static_cast<size_t>(read_result) != message->data_num_bytes()) {
43     LOG(ERROR) << "Invalid node channel message";
44     error = true;
45   } else if (incoming_fds.size() != expected_num_handles) {
46     LOG(ERROR) << "Received unexpected number of handles";
47     error = true;
48   }
49 
50   if (error)
51     return nullptr;
52 
53   const BrokerMessageHeader* header =
54       reinterpret_cast<const BrokerMessageHeader*>(message->payload());
55   if (header->type != expected_type) {
56     LOG(ERROR) << "Unexpected message";
57     return nullptr;
58   }
59 
60   incoming_handles->reserve(incoming_fds.size());
61   for (size_t i = 0; i < incoming_fds.size(); ++i)
62     incoming_handles->emplace_back(std::move(incoming_fds[i]));
63 
64   return message;
65 }
66 
67 }  // namespace
68 
Broker(PlatformHandle handle,bool wait_for_channel_handle)69 Broker::Broker(PlatformHandle handle, bool wait_for_channel_handle)
70     : sync_channel_(std::move(handle)) {
71   CHECK(sync_channel_.is_valid());
72 
73   int fd = sync_channel_.GetFD().get();
74   // Mark the channel as blocking.
75   int flags = fcntl(fd, F_GETFL);
76   PCHECK(flags != -1);
77   flags = fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);
78   PCHECK(flags != -1);
79 
80   if (!wait_for_channel_handle)
81     return;
82 
83   // Wait for the first message, which should contain a handle.
84   std::vector<PlatformHandle> incoming_platform_handles;
85   if (WaitForBrokerMessage(fd, BrokerMessageType::INIT, 1, 0,
86                            &incoming_platform_handles)) {
87     inviter_endpoint_ =
88         PlatformChannelEndpoint(std::move(incoming_platform_handles[0]));
89   }
90 }
91 
92 Broker::~Broker() = default;
93 
GetInviterEndpoint()94 PlatformChannelEndpoint Broker::GetInviterEndpoint() {
95   return std::move(inviter_endpoint_);
96 }
97 
GetWritableSharedMemoryRegion(size_t num_bytes)98 base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
99     size_t num_bytes) {
100   base::AutoLock lock(lock_);
101 
102   BufferRequestData* buffer_request;
103   Channel::MessagePtr out_message = CreateBrokerMessage(
104       BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
105   buffer_request->size = num_bytes;
106   ssize_t write_result =
107       SocketWrite(sync_channel_.GetFD().get(), out_message->data(),
108                   out_message->data_num_bytes());
109   if (write_result < 0) {
110     PLOG(ERROR) << "Error sending sync broker message";
111     return base::WritableSharedMemoryRegion();
112   } else if (static_cast<size_t>(write_result) !=
113              out_message->data_num_bytes()) {
114     LOG(ERROR) << "Error sending complete broker message";
115     return base::WritableSharedMemoryRegion();
116   }
117 
118 #if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_MAC)
119   // Non-POSIX systems, as well as Android and Mac, only use a single handle to
120   // represent a writable region.
121   constexpr size_t kNumExpectedHandles = 1;
122 #else
123   constexpr size_t kNumExpectedHandles = 2;
124 #endif
125 
126   std::vector<PlatformHandle> handles;
127   Channel::MessagePtr message = WaitForBrokerMessage(
128       sync_channel_.GetFD().get(), BrokerMessageType::BUFFER_RESPONSE,
129       kNumExpectedHandles, sizeof(BufferResponseData), &handles);
130   if (message) {
131     const BufferResponseData* data;
132     if (!GetBrokerMessageData(message.get(), &data))
133       return base::WritableSharedMemoryRegion();
134 
135     if (handles.size() == 1)
136       handles.emplace_back();
137     return base::WritableSharedMemoryRegion::Deserialize(
138         base::subtle::PlatformSharedMemoryRegion::Take(
139             CreateSharedMemoryRegionHandleFromPlatformHandles(
140                 std::move(handles[0]), std::move(handles[1])),
141             base::subtle::PlatformSharedMemoryRegion::Mode::kWritable,
142             num_bytes,
143             base::UnguessableToken::Deserialize(data->guid_high,
144                                                 data->guid_low)));
145   }
146 
147   return base::WritableSharedMemoryRegion();
148 }
149 
150 }  // namespace core
151 }  // namespace mojo
152