1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/ipc/host_impl.h"
18
19 #include <inttypes.h>
20
21 #include <algorithm>
22 #include <utility>
23
24 #include "perfetto/base/task_runner.h"
25 #include "perfetto/ext/base/utils.h"
26 #include "perfetto/ext/ipc/service.h"
27 #include "perfetto/ext/ipc/service_descriptor.h"
28
29 #include "protos/perfetto/ipc/wire_protocol.gen.h"
30
31 // TODO(primiano): put limits on #connections/uid and req. queue (b/69093705).
32
33 namespace perfetto {
34 namespace ipc {
35
36 // static
CreateInstance(const char * socket_name,base::TaskRunner * task_runner)37 std::unique_ptr<Host> Host::CreateInstance(const char* socket_name,
38 base::TaskRunner* task_runner) {
39 std::unique_ptr<HostImpl> host(new HostImpl(socket_name, task_runner));
40 if (!host->sock() || !host->sock()->is_listening())
41 return nullptr;
42 return std::unique_ptr<Host>(std::move(host));
43 }
44
45 // static
CreateInstance(base::ScopedFile socket_fd,base::TaskRunner * task_runner)46 std::unique_ptr<Host> Host::CreateInstance(base::ScopedFile socket_fd,
47 base::TaskRunner* task_runner) {
48 std::unique_ptr<HostImpl> host(
49 new HostImpl(std::move(socket_fd), task_runner));
50 if (!host->sock() || !host->sock()->is_listening())
51 return nullptr;
52 return std::unique_ptr<Host>(std::move(host));
53 }
54
HostImpl(base::ScopedFile socket_fd,base::TaskRunner * task_runner)55 HostImpl::HostImpl(base::ScopedFile socket_fd, base::TaskRunner* task_runner)
56 : task_runner_(task_runner), weak_ptr_factory_(this) {
57 PERFETTO_DCHECK_THREAD(thread_checker_);
58 sock_ = base::UnixSocket::Listen(std::move(socket_fd), this, task_runner_,
59 base::SockFamily::kUnix,
60 base::SockType::kStream);
61 }
62
HostImpl(const char * socket_name,base::TaskRunner * task_runner)63 HostImpl::HostImpl(const char* socket_name, base::TaskRunner* task_runner)
64 : task_runner_(task_runner), weak_ptr_factory_(this) {
65 PERFETTO_DCHECK_THREAD(thread_checker_);
66 sock_ = base::UnixSocket::Listen(socket_name, this, task_runner_,
67 base::SockFamily::kUnix,
68 base::SockType::kStream);
69 }
70
71 HostImpl::~HostImpl() = default;
72
ExposeService(std::unique_ptr<Service> service)73 bool HostImpl::ExposeService(std::unique_ptr<Service> service) {
74 PERFETTO_DCHECK_THREAD(thread_checker_);
75 const std::string& service_name = service->GetDescriptor().service_name;
76 if (GetServiceByName(service_name)) {
77 PERFETTO_DLOG("Duplicate ExposeService(): %s", service_name.c_str());
78 return false;
79 }
80 ServiceID sid = ++last_service_id_;
81 ExposedService exposed_service(sid, service_name, std::move(service));
82 services_.emplace(sid, std::move(exposed_service));
83 return true;
84 }
85
OnNewIncomingConnection(base::UnixSocket *,std::unique_ptr<base::UnixSocket> new_conn)86 void HostImpl::OnNewIncomingConnection(
87 base::UnixSocket*,
88 std::unique_ptr<base::UnixSocket> new_conn) {
89 PERFETTO_DCHECK_THREAD(thread_checker_);
90 std::unique_ptr<ClientConnection> client(new ClientConnection());
91 ClientID client_id = ++last_client_id_;
92 clients_by_socket_[new_conn.get()] = client.get();
93 client->id = client_id;
94 client->sock = std::move(new_conn);
95 clients_[client_id] = std::move(client);
96 }
97
OnDataAvailable(base::UnixSocket * sock)98 void HostImpl::OnDataAvailable(base::UnixSocket* sock) {
99 PERFETTO_DCHECK_THREAD(thread_checker_);
100 auto it = clients_by_socket_.find(sock);
101 if (it == clients_by_socket_.end())
102 return;
103 ClientConnection* client = it->second;
104 BufferedFrameDeserializer& frame_deserializer = client->frame_deserializer;
105
106 size_t rsize;
107 do {
108 auto buf = frame_deserializer.BeginReceive();
109 base::ScopedFile fd;
110 rsize = client->sock->Receive(buf.data, buf.size, &fd);
111 if (fd) {
112 PERFETTO_DCHECK(!client->received_fd);
113 client->received_fd = std::move(fd);
114 }
115 if (!frame_deserializer.EndReceive(rsize))
116 return OnDisconnect(client->sock.get());
117 } while (rsize > 0);
118
119 for (;;) {
120 std::unique_ptr<Frame> frame = frame_deserializer.PopNextFrame();
121 if (!frame)
122 break;
123 OnReceivedFrame(client, *frame);
124 }
125 }
126
OnReceivedFrame(ClientConnection * client,const Frame & req_frame)127 void HostImpl::OnReceivedFrame(ClientConnection* client,
128 const Frame& req_frame) {
129 if (req_frame.has_msg_bind_service())
130 return OnBindService(client, req_frame);
131 if (req_frame.has_msg_invoke_method())
132 return OnInvokeMethod(client, req_frame);
133
134 PERFETTO_DLOG("Received invalid RPC frame from client %" PRIu64, client->id);
135 Frame reply_frame;
136 reply_frame.set_request_id(req_frame.request_id());
137 reply_frame.mutable_msg_request_error()->set_error("unknown request");
138 SendFrame(client, reply_frame);
139 }
140
OnBindService(ClientConnection * client,const Frame & req_frame)141 void HostImpl::OnBindService(ClientConnection* client, const Frame& req_frame) {
142 // Binding a service doesn't do anything major. It just returns back the
143 // service id and its method map.
144 const Frame::BindService& req = req_frame.msg_bind_service();
145 Frame reply_frame;
146 reply_frame.set_request_id(req_frame.request_id());
147 auto* reply = reply_frame.mutable_msg_bind_service_reply();
148 const ExposedService* service = GetServiceByName(req.service_name());
149 if (service) {
150 reply->set_success(true);
151 reply->set_service_id(service->id);
152 uint32_t method_id = 1; // method ids start at index 1.
153 for (const auto& desc_method : service->instance->GetDescriptor().methods) {
154 Frame::BindServiceReply::MethodInfo* method_info = reply->add_methods();
155 method_info->set_name(desc_method.name);
156 method_info->set_id(method_id++);
157 }
158 }
159 SendFrame(client, reply_frame);
160 }
161
OnInvokeMethod(ClientConnection * client,const Frame & req_frame)162 void HostImpl::OnInvokeMethod(ClientConnection* client,
163 const Frame& req_frame) {
164 const Frame::InvokeMethod& req = req_frame.msg_invoke_method();
165 Frame reply_frame;
166 RequestID request_id = req_frame.request_id();
167 reply_frame.set_request_id(request_id);
168 reply_frame.mutable_msg_invoke_method_reply()->set_success(false);
169 auto svc_it = services_.find(req.service_id());
170 if (svc_it == services_.end())
171 return SendFrame(client, reply_frame); // |success| == false by default.
172
173 Service* service = svc_it->second.instance.get();
174 const ServiceDescriptor& svc = service->GetDescriptor();
175 const auto& methods = svc.methods;
176 const uint32_t method_id = req.method_id();
177 if (method_id == 0 || method_id > methods.size())
178 return SendFrame(client, reply_frame);
179
180 const ServiceDescriptor::Method& method = methods[method_id - 1];
181 std::unique_ptr<ProtoMessage> decoded_req_args(
182 method.request_proto_decoder(req.args_proto()));
183 if (!decoded_req_args)
184 return SendFrame(client, reply_frame);
185
186 Deferred<ProtoMessage> deferred_reply;
187 base::WeakPtr<HostImpl> host_weak_ptr = weak_ptr_factory_.GetWeakPtr();
188 ClientID client_id = client->id;
189
190 if (!req.drop_reply()) {
191 deferred_reply.Bind([host_weak_ptr, client_id,
192 request_id](AsyncResult<ProtoMessage> reply) {
193 if (!host_weak_ptr)
194 return; // The reply came too late, the HostImpl has gone.
195 host_weak_ptr->ReplyToMethodInvocation(client_id, request_id,
196 std::move(reply));
197 });
198 }
199
200 service->client_info_ = ClientInfo(client->id, client->sock->peer_uid());
201 service->received_fd_ = &client->received_fd;
202 method.invoker(service, *decoded_req_args, std::move(deferred_reply));
203 service->received_fd_ = nullptr;
204 service->client_info_ = ClientInfo();
205 }
206
ReplyToMethodInvocation(ClientID client_id,RequestID request_id,AsyncResult<ProtoMessage> reply)207 void HostImpl::ReplyToMethodInvocation(ClientID client_id,
208 RequestID request_id,
209 AsyncResult<ProtoMessage> reply) {
210 auto client_iter = clients_.find(client_id);
211 if (client_iter == clients_.end())
212 return; // client has disconnected by the time we got the async reply.
213
214 ClientConnection* client = client_iter->second.get();
215 Frame reply_frame;
216 reply_frame.set_request_id(request_id);
217
218 // TODO(fmayer): add a test to guarantee that the reply is consumed within the
219 // same call stack and not kept around. ConsumerIPCService::OnTraceData()
220 // relies on this behavior.
221 auto* reply_frame_data = reply_frame.mutable_msg_invoke_method_reply();
222 reply_frame_data->set_has_more(reply.has_more());
223 if (reply.success()) {
224 std::string reply_proto = reply->SerializeAsString();
225 reply_frame_data->set_reply_proto(reply_proto);
226 reply_frame_data->set_success(true);
227 }
228 SendFrame(client, reply_frame, reply.fd());
229 }
230
231 // static
SendFrame(ClientConnection * client,const Frame & frame,int fd)232 void HostImpl::SendFrame(ClientConnection* client, const Frame& frame, int fd) {
233 std::string buf = BufferedFrameDeserializer::Serialize(frame);
234
235 // TODO(primiano): this should do non-blocking I/O. But then what if the
236 // socket buffer is full? We might want to either drop the request or throttle
237 // the send and PostTask the reply later? Right now we are making Send()
238 // blocking as a workaround. Propagate bakpressure to the caller instead.
239 bool res = client->sock->Send(buf.data(), buf.size(), fd);
240 PERFETTO_CHECK(res || !client->sock->is_connected());
241 }
242
OnDisconnect(base::UnixSocket * sock)243 void HostImpl::OnDisconnect(base::UnixSocket* sock) {
244 PERFETTO_DCHECK_THREAD(thread_checker_);
245 auto it = clients_by_socket_.find(sock);
246 if (it == clients_by_socket_.end())
247 return;
248 ClientID client_id = it->second->id;
249 ClientInfo client_info(client_id, sock->peer_uid());
250 clients_by_socket_.erase(it);
251 PERFETTO_DCHECK(clients_.count(client_id));
252 clients_.erase(client_id);
253
254 for (const auto& service_it : services_) {
255 Service& service = *service_it.second.instance;
256 service.client_info_ = client_info;
257 service.OnClientDisconnected();
258 service.client_info_ = ClientInfo();
259 }
260 }
261
GetServiceByName(const std::string & name)262 const HostImpl::ExposedService* HostImpl::GetServiceByName(
263 const std::string& name) {
264 // This could be optimized by using another map<name,ServiceID>. However this
265 // is used only by Bind/ExposeService that are quite rare (once per client
266 // connection and once per service instance), not worth it.
267 for (const auto& it : services_) {
268 if (it.second.name == name)
269 return &it.second;
270 }
271 return nullptr;
272 }
273
ExposedService(ServiceID id_,const std::string & name_,std::unique_ptr<Service> instance_)274 HostImpl::ExposedService::ExposedService(ServiceID id_,
275 const std::string& name_,
276 std::unique_ptr<Service> instance_)
277 : id(id_), name(name_), instance(std::move(instance_)) {}
278
279 HostImpl::ExposedService::ExposedService(ExposedService&&) noexcept = default;
280 HostImpl::ExposedService& HostImpl::ExposedService::operator=(
281 HostImpl::ExposedService&&) = default;
282 HostImpl::ExposedService::~ExposedService() = default;
283
284 HostImpl::ClientConnection::~ClientConnection() = default;
285
286 } // namespace ipc
287 } // namespace perfetto
288