1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "dispatch.h"

#include <algorithm>
#include <cassert>
#include <limits>

#include "cbor.h"
#include "error_support.h"
#include "find_by_first.h"
#include "frontend_channel.h"
12
13 namespace crdtp {
14 // =============================================================================
15 // DispatchResponse - Error status and chaining / fall through
16 // =============================================================================
17
18 // static
Success()19 DispatchResponse DispatchResponse::Success() {
20 DispatchResponse result;
21 result.code_ = DispatchCode::SUCCESS;
22 return result;
23 }
24
25 // static
FallThrough()26 DispatchResponse DispatchResponse::FallThrough() {
27 DispatchResponse result;
28 result.code_ = DispatchCode::FALL_THROUGH;
29 return result;
30 }
31
32 // static
ParseError(std::string message)33 DispatchResponse DispatchResponse::ParseError(std::string message) {
34 DispatchResponse result;
35 result.code_ = DispatchCode::PARSE_ERROR;
36 result.message_ = std::move(message);
37 return result;
38 }
39
40 // static
InvalidRequest(std::string message)41 DispatchResponse DispatchResponse::InvalidRequest(std::string message) {
42 DispatchResponse result;
43 result.code_ = DispatchCode::INVALID_REQUEST;
44 result.message_ = std::move(message);
45 return result;
46 }
47
48 // static
MethodNotFound(std::string message)49 DispatchResponse DispatchResponse::MethodNotFound(std::string message) {
50 DispatchResponse result;
51 result.code_ = DispatchCode::METHOD_NOT_FOUND;
52 result.message_ = std::move(message);
53 return result;
54 }
55
56 // static
InvalidParams(std::string message)57 DispatchResponse DispatchResponse::InvalidParams(std::string message) {
58 DispatchResponse result;
59 result.code_ = DispatchCode::INVALID_PARAMS;
60 result.message_ = std::move(message);
61 return result;
62 }
63
64 // static
InternalError()65 DispatchResponse DispatchResponse::InternalError() {
66 DispatchResponse result;
67 result.code_ = DispatchCode::INTERNAL_ERROR;
68 result.message_ = "Internal error";
69 return result;
70 }
71
72 // static
ServerError(std::string message)73 DispatchResponse DispatchResponse::ServerError(std::string message) {
74 DispatchResponse result;
75 result.code_ = DispatchCode::SERVER_ERROR;
76 result.message_ = std::move(message);
77 return result;
78 }
79
80 // =============================================================================
81 // Dispatchable - a shallow parser for CBOR encoded DevTools messages
82 // =============================================================================
namespace {
// Number of bytes that precede an envelope's contents in the encoded
// message: two single-byte markers followed by a 32-bit length. (Presumably
// the CBOR tag byte and the byte-string initial byte; confirm against cbor.h.)
constexpr size_t kEncodedEnvelopeHeaderSize =
    sizeof(uint8_t) + sizeof(uint8_t) + sizeof(uint32_t);
}  // namespace
86
Dispatchable(span<uint8_t> serialized)87 Dispatchable::Dispatchable(span<uint8_t> serialized) : serialized_(serialized) {
88 Status s = cbor::CheckCBORMessage(serialized);
89 if (!s.ok()) {
90 status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, s.pos};
91 return;
92 }
93 cbor::CBORTokenizer tokenizer(serialized);
94 if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
95 status_ = tokenizer.Status();
96 return;
97 }
98
99 // We checked for the envelope start byte above, so the tokenizer
100 // must agree here, since it's not an error.
101 assert(tokenizer.TokenTag() == cbor::CBORTokenTag::ENVELOPE);
102
103 // Before we enter the envelope, we save the position that we
104 // expect to see after we're done parsing the envelope contents.
105 // This way we can compare and produce an error if the contents
106 // didn't fit exactly into the envelope length.
107 const size_t pos_past_envelope = tokenizer.Status().pos +
108 kEncodedEnvelopeHeaderSize +
109 tokenizer.GetEnvelopeContents().size();
110 tokenizer.EnterEnvelope();
111 if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
112 status_ = tokenizer.Status();
113 return;
114 }
115 if (tokenizer.TokenTag() != cbor::CBORTokenTag::MAP_START) {
116 status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, tokenizer.Status().pos};
117 return;
118 }
119 assert(tokenizer.TokenTag() == cbor::CBORTokenTag::MAP_START);
120 tokenizer.Next(); // Now we should be pointed at the map key.
121 while (tokenizer.TokenTag() != cbor::CBORTokenTag::STOP) {
122 switch (tokenizer.TokenTag()) {
123 case cbor::CBORTokenTag::DONE:
124 status_ =
125 Status{Error::CBOR_UNEXPECTED_EOF_IN_MAP, tokenizer.Status().pos};
126 return;
127 case cbor::CBORTokenTag::ERROR_VALUE:
128 status_ = tokenizer.Status();
129 return;
130 case cbor::CBORTokenTag::STRING8:
131 if (!MaybeParseProperty(&tokenizer))
132 return;
133 break;
134 default:
135 // We require the top-level keys to be UTF8 (US-ASCII in practice).
136 status_ = Status{Error::CBOR_INVALID_MAP_KEY, tokenizer.Status().pos};
137 return;
138 }
139 }
140 tokenizer.Next();
141 if (!has_call_id_) {
142 status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
143 tokenizer.Status().pos};
144 return;
145 }
146 if (method_.empty()) {
147 status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
148 tokenizer.Status().pos};
149 return;
150 }
151 // The contents of the envelope parsed OK, now check that we're at
152 // the expected position.
153 if (pos_past_envelope != tokenizer.Status().pos) {
154 status_ = Status{Error::CBOR_ENVELOPE_CONTENTS_LENGTH_MISMATCH,
155 tokenizer.Status().pos};
156 return;
157 }
158 if (tokenizer.TokenTag() != cbor::CBORTokenTag::DONE) {
159 status_ = Status{Error::CBOR_TRAILING_JUNK, tokenizer.Status().pos};
160 return;
161 }
162 }
163
ok() const164 bool Dispatchable::ok() const {
165 return status_.ok();
166 }
167
DispatchError() const168 DispatchResponse Dispatchable::DispatchError() const {
169 // TODO(johannes): Replace with DCHECK / similar?
170 if (status_.ok())
171 return DispatchResponse::Success();
172
173 if (status_.IsMessageError())
174 return DispatchResponse::InvalidRequest(status_.Message());
175 return DispatchResponse::ParseError(status_.ToASCIIString());
176 }
177
MaybeParseProperty(cbor::CBORTokenizer * tokenizer)178 bool Dispatchable::MaybeParseProperty(cbor::CBORTokenizer* tokenizer) {
179 span<uint8_t> property_name = tokenizer->GetString8();
180 if (SpanEquals(SpanFrom("id"), property_name))
181 return MaybeParseCallId(tokenizer);
182 if (SpanEquals(SpanFrom("method"), property_name))
183 return MaybeParseMethod(tokenizer);
184 if (SpanEquals(SpanFrom("params"), property_name))
185 return MaybeParseParams(tokenizer);
186 if (SpanEquals(SpanFrom("sessionId"), property_name))
187 return MaybeParseSessionId(tokenizer);
188 status_ =
189 Status{Error::MESSAGE_HAS_UNKNOWN_PROPERTY, tokenizer->Status().pos};
190 return false;
191 }
192
MaybeParseCallId(cbor::CBORTokenizer * tokenizer)193 bool Dispatchable::MaybeParseCallId(cbor::CBORTokenizer* tokenizer) {
194 if (has_call_id_) {
195 status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
196 return false;
197 }
198 tokenizer->Next();
199 if (tokenizer->TokenTag() != cbor::CBORTokenTag::INT32) {
200 status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
201 tokenizer->Status().pos};
202 return false;
203 }
204 call_id_ = tokenizer->GetInt32();
205 has_call_id_ = true;
206 tokenizer->Next();
207 return true;
208 }
209
MaybeParseMethod(cbor::CBORTokenizer * tokenizer)210 bool Dispatchable::MaybeParseMethod(cbor::CBORTokenizer* tokenizer) {
211 if (!method_.empty()) {
212 status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
213 return false;
214 }
215 tokenizer->Next();
216 if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
217 status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
218 tokenizer->Status().pos};
219 return false;
220 }
221 method_ = tokenizer->GetString8();
222 tokenizer->Next();
223 return true;
224 }
225
MaybeParseParams(cbor::CBORTokenizer * tokenizer)226 bool Dispatchable::MaybeParseParams(cbor::CBORTokenizer* tokenizer) {
227 if (params_seen_) {
228 status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
229 return false;
230 }
231 params_seen_ = true;
232 tokenizer->Next();
233 if (tokenizer->TokenTag() == cbor::CBORTokenTag::NULL_VALUE) {
234 tokenizer->Next();
235 return true;
236 }
237 if (tokenizer->TokenTag() != cbor::CBORTokenTag::ENVELOPE) {
238 status_ = Status{Error::MESSAGE_MAY_HAVE_OBJECT_PARAMS_PROPERTY,
239 tokenizer->Status().pos};
240 return false;
241 }
242 params_ = tokenizer->GetEnvelope();
243 tokenizer->Next();
244 return true;
245 }
246
MaybeParseSessionId(cbor::CBORTokenizer * tokenizer)247 bool Dispatchable::MaybeParseSessionId(cbor::CBORTokenizer* tokenizer) {
248 if (!session_id_.empty()) {
249 status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
250 return false;
251 }
252 tokenizer->Next();
253 if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
254 status_ = Status{Error::MESSAGE_MAY_HAVE_STRING_SESSION_ID_PROPERTY,
255 tokenizer->Status().pos};
256 return false;
257 }
258 session_id_ = tokenizer->GetString8();
259 tokenizer->Next();
260 return true;
261 }
262
263 namespace {
264 class ProtocolError : public Serializable {
265 public:
ProtocolError(DispatchResponse dispatch_response)266 explicit ProtocolError(DispatchResponse dispatch_response)
267 : dispatch_response_(std::move(dispatch_response)) {}
268
AppendSerialized(std::vector<uint8_t> * out) const269 void AppendSerialized(std::vector<uint8_t>* out) const override {
270 Status status;
271 std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
272 encoder->HandleMapBegin();
273 if (has_call_id_) {
274 encoder->HandleString8(SpanFrom("id"));
275 encoder->HandleInt32(call_id_);
276 }
277 encoder->HandleString8(SpanFrom("error"));
278 encoder->HandleMapBegin();
279 encoder->HandleString8(SpanFrom("code"));
280 encoder->HandleInt32(static_cast<int32_t>(dispatch_response_.Code()));
281 encoder->HandleString8(SpanFrom("message"));
282 encoder->HandleString8(SpanFrom(dispatch_response_.Message()));
283 if (!data_.empty()) {
284 encoder->HandleString8(SpanFrom("data"));
285 encoder->HandleString8(SpanFrom(data_));
286 }
287 encoder->HandleMapEnd();
288 encoder->HandleMapEnd();
289 assert(status.ok());
290 }
291
SetCallId(int call_id)292 void SetCallId(int call_id) {
293 has_call_id_ = true;
294 call_id_ = call_id;
295 }
SetData(std::string data)296 void SetData(std::string data) { data_ = std::move(data); }
297
298 private:
299 const DispatchResponse dispatch_response_;
300 std::string data_;
301 int call_id_ = 0;
302 bool has_call_id_ = false;
303 };
304 } // namespace
305
306 // =============================================================================
// Helpers for creating protocol responses and notifications.
308 // =============================================================================
309
CreateErrorResponse(int call_id,DispatchResponse dispatch_response,const ErrorSupport * errors)310 std::unique_ptr<Serializable> CreateErrorResponse(
311 int call_id,
312 DispatchResponse dispatch_response,
313 const ErrorSupport* errors) {
314 auto protocol_error =
315 std::make_unique<ProtocolError>(std::move(dispatch_response));
316 protocol_error->SetCallId(call_id);
317 if (errors && !errors->Errors().empty()) {
318 protocol_error->SetData(
319 std::string(errors->Errors().begin(), errors->Errors().end()));
320 }
321 return protocol_error;
322 }
323
CreateErrorNotification(DispatchResponse dispatch_response)324 std::unique_ptr<Serializable> CreateErrorNotification(
325 DispatchResponse dispatch_response) {
326 return std::make_unique<ProtocolError>(std::move(dispatch_response));
327 }
328
329 namespace {
330 class Response : public Serializable {
331 public:
Response(int call_id,std::unique_ptr<Serializable> params)332 Response(int call_id, std::unique_ptr<Serializable> params)
333 : call_id_(call_id), params_(std::move(params)) {}
334
AppendSerialized(std::vector<uint8_t> * out) const335 void AppendSerialized(std::vector<uint8_t>* out) const override {
336 Status status;
337 std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
338 encoder->HandleMapBegin();
339 encoder->HandleString8(SpanFrom("id"));
340 encoder->HandleInt32(call_id_);
341 encoder->HandleString8(SpanFrom("result"));
342 if (params_) {
343 params_->AppendSerialized(out);
344 } else {
345 encoder->HandleMapBegin();
346 encoder->HandleMapEnd();
347 }
348 encoder->HandleMapEnd();
349 assert(status.ok());
350 }
351
352 private:
353 const int call_id_;
354 std::unique_ptr<Serializable> params_;
355 };
356
357 class Notification : public Serializable {
358 public:
Notification(const char * method,std::unique_ptr<Serializable> params)359 Notification(const char* method, std::unique_ptr<Serializable> params)
360 : method_(method), params_(std::move(params)) {}
361
AppendSerialized(std::vector<uint8_t> * out) const362 void AppendSerialized(std::vector<uint8_t>* out) const override {
363 Status status;
364 std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
365 encoder->HandleMapBegin();
366 encoder->HandleString8(SpanFrom("method"));
367 encoder->HandleString8(SpanFrom(method_));
368 encoder->HandleString8(SpanFrom("params"));
369 if (params_) {
370 params_->AppendSerialized(out);
371 } else {
372 encoder->HandleMapBegin();
373 encoder->HandleMapEnd();
374 }
375 encoder->HandleMapEnd();
376 assert(status.ok());
377 }
378
379 private:
380 const char* method_;
381 std::unique_ptr<Serializable> params_;
382 };
383 } // namespace
384
CreateResponse(int call_id,std::unique_ptr<Serializable> params)385 std::unique_ptr<Serializable> CreateResponse(
386 int call_id,
387 std::unique_ptr<Serializable> params) {
388 return std::make_unique<Response>(call_id, std::move(params));
389 }
390
CreateNotification(const char * method,std::unique_ptr<Serializable> params)391 std::unique_ptr<Serializable> CreateNotification(
392 const char* method,
393 std::unique_ptr<Serializable> params) {
394 return std::make_unique<Notification>(method, std::move(params));
395 }
396
397 // =============================================================================
// DomainDispatcher - Dispatching between protocol methods within a domain.
399 // =============================================================================
WeakPtr(DomainDispatcher * dispatcher)400 DomainDispatcher::WeakPtr::WeakPtr(DomainDispatcher* dispatcher)
401 : dispatcher_(dispatcher) {}
402
~WeakPtr()403 DomainDispatcher::WeakPtr::~WeakPtr() {
404 if (dispatcher_)
405 dispatcher_->weak_ptrs_.erase(this);
406 }
407
408 DomainDispatcher::Callback::~Callback() = default;
409
dispose()410 void DomainDispatcher::Callback::dispose() {
411 backend_impl_ = nullptr;
412 }
413
Callback(std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,int call_id,span<uint8_t> method,span<uint8_t> message)414 DomainDispatcher::Callback::Callback(
415 std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,
416 int call_id,
417 span<uint8_t> method,
418 span<uint8_t> message)
419 : backend_impl_(std::move(backend_impl)),
420 call_id_(call_id),
421 method_(method),
422 message_(message.begin(), message.end()) {}
423
sendIfActive(std::unique_ptr<Serializable> partialMessage,const DispatchResponse & response)424 void DomainDispatcher::Callback::sendIfActive(
425 std::unique_ptr<Serializable> partialMessage,
426 const DispatchResponse& response) {
427 if (!backend_impl_ || !backend_impl_->get())
428 return;
429 backend_impl_->get()->sendResponse(call_id_, response,
430 std::move(partialMessage));
431 backend_impl_ = nullptr;
432 }
433
fallThroughIfActive()434 void DomainDispatcher::Callback::fallThroughIfActive() {
435 if (!backend_impl_ || !backend_impl_->get())
436 return;
437 backend_impl_->get()->channel()->FallThrough(call_id_, method_,
438 SpanFrom(message_));
439 backend_impl_ = nullptr;
440 }
441
DomainDispatcher(FrontendChannel * frontendChannel)442 DomainDispatcher::DomainDispatcher(FrontendChannel* frontendChannel)
443 : frontend_channel_(frontendChannel) {}
444
~DomainDispatcher()445 DomainDispatcher::~DomainDispatcher() {
446 clearFrontend();
447 }
448
sendResponse(int call_id,const DispatchResponse & response,std::unique_ptr<Serializable> result)449 void DomainDispatcher::sendResponse(int call_id,
450 const DispatchResponse& response,
451 std::unique_ptr<Serializable> result) {
452 if (!frontend_channel_)
453 return;
454 std::unique_ptr<Serializable> serializable;
455 if (response.IsError()) {
456 serializable = CreateErrorResponse(call_id, response);
457 } else {
458 serializable = CreateResponse(call_id, std::move(result));
459 }
460 frontend_channel_->SendProtocolResponse(call_id, std::move(serializable));
461 }
462
MaybeReportInvalidParams(const Dispatchable & dispatchable,const ErrorSupport & errors)463 bool DomainDispatcher::MaybeReportInvalidParams(
464 const Dispatchable& dispatchable,
465 const ErrorSupport& errors) {
466 if (errors.Errors().empty())
467 return false;
468 if (frontend_channel_) {
469 frontend_channel_->SendProtocolResponse(
470 dispatchable.CallId(),
471 CreateErrorResponse(
472 dispatchable.CallId(),
473 DispatchResponse::InvalidParams("Invalid parameters"), &errors));
474 }
475 return true;
476 }
477
clearFrontend()478 void DomainDispatcher::clearFrontend() {
479 frontend_channel_ = nullptr;
480 for (auto& weak : weak_ptrs_)
481 weak->dispose();
482 weak_ptrs_.clear();
483 }
484
weakPtr()485 std::unique_ptr<DomainDispatcher::WeakPtr> DomainDispatcher::weakPtr() {
486 auto weak = std::make_unique<DomainDispatcher::WeakPtr>(this);
487 weak_ptrs_.insert(weak.get());
488 return weak;
489 }
490
491 // =============================================================================
492 // UberDispatcher - dispatches between domains (backends).
493 // =============================================================================
DispatchResult(bool method_found,std::function<void ()> runnable)494 UberDispatcher::DispatchResult::DispatchResult(bool method_found,
495 std::function<void()> runnable)
496 : method_found_(method_found), runnable_(runnable) {}
497
Run()498 void UberDispatcher::DispatchResult::Run() {
499 if (!runnable_)
500 return;
501 runnable_();
502 runnable_ = nullptr;
503 }
504
UberDispatcher(FrontendChannel * frontend_channel)505 UberDispatcher::UberDispatcher(FrontendChannel* frontend_channel)
506 : frontend_channel_(frontend_channel) {
507 assert(frontend_channel);
508 }
509
510 UberDispatcher::~UberDispatcher() = default;
511
512 constexpr size_t kNotFound = std::numeric_limits<size_t>::max();
513
514 namespace {
DotIdx(span<uint8_t> method)515 size_t DotIdx(span<uint8_t> method) {
516 const void* p = memchr(method.data(), '.', method.size());
517 return p ? reinterpret_cast<const uint8_t*>(p) - method.data() : kNotFound;
518 }
519 } // namespace
520
Dispatch(const Dispatchable & dispatchable) const521 UberDispatcher::DispatchResult UberDispatcher::Dispatch(
522 const Dispatchable& dispatchable) const {
523 span<uint8_t> method = FindByFirst(redirects_, dispatchable.Method(),
524 /*default_value=*/dispatchable.Method());
525 size_t dot_idx = DotIdx(method);
526 if (dot_idx != kNotFound) {
527 span<uint8_t> domain = method.subspan(0, dot_idx);
528 span<uint8_t> command = method.subspan(dot_idx + 1);
529 DomainDispatcher* dispatcher = FindByFirst(dispatchers_, domain);
530 if (dispatcher) {
531 std::function<void(const Dispatchable&)> dispatched =
532 dispatcher->Dispatch(command);
533 if (dispatched) {
534 return DispatchResult(
535 true, [dispatchable, dispatched = std::move(dispatched)]() {
536 dispatched(dispatchable);
537 });
538 }
539 }
540 }
541 return DispatchResult(false, [this, dispatchable]() {
542 frontend_channel_->SendProtocolResponse(
543 dispatchable.CallId(),
544 CreateErrorResponse(dispatchable.CallId(),
545 DispatchResponse::MethodNotFound(
546 "'" +
547 std::string(dispatchable.Method().begin(),
548 dispatchable.Method().end()) +
549 "' wasn't found")));
550 });
551 }
552
553 template <typename T>
554 struct FirstLessThan {
operator ()crdtp::FirstLessThan555 bool operator()(const std::pair<span<uint8_t>, T>& left,
556 const std::pair<span<uint8_t>, T>& right) {
557 return SpanLessThan(left.first, right.first);
558 }
559 };
560
WireBackend(span<uint8_t> domain,const std::vector<std::pair<span<uint8_t>,span<uint8_t>>> & sorted_redirects,std::unique_ptr<DomainDispatcher> dispatcher)561 void UberDispatcher::WireBackend(
562 span<uint8_t> domain,
563 const std::vector<std::pair<span<uint8_t>, span<uint8_t>>>&
564 sorted_redirects,
565 std::unique_ptr<DomainDispatcher> dispatcher) {
566 auto it = redirects_.insert(redirects_.end(), sorted_redirects.begin(),
567 sorted_redirects.end());
568 std::inplace_merge(redirects_.begin(), it, redirects_.end(),
569 FirstLessThan<span<uint8_t>>());
570 auto jt = dispatchers_.insert(dispatchers_.end(),
571 std::make_pair(domain, std::move(dispatcher)));
572 std::inplace_merge(dispatchers_.begin(), jt, dispatchers_.end(),
573 FirstLessThan<std::unique_ptr<DomainDispatcher>>());
574 }
575
576 } // namespace crdtp
577