1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "dispatch.h"
6 
7 #include <cassert>
8 #include "cbor.h"
9 #include "error_support.h"
10 #include "find_by_first.h"
11 #include "frontend_channel.h"
12 
13 namespace crdtp {
14 // =============================================================================
15 // DispatchResponse - Error status and chaining / fall through
16 // =============================================================================
17 
18 // static
Success()19 DispatchResponse DispatchResponse::Success() {
20   DispatchResponse result;
21   result.code_ = DispatchCode::SUCCESS;
22   return result;
23 }
24 
25 // static
FallThrough()26 DispatchResponse DispatchResponse::FallThrough() {
27   DispatchResponse result;
28   result.code_ = DispatchCode::FALL_THROUGH;
29   return result;
30 }
31 
32 // static
ParseError(std::string message)33 DispatchResponse DispatchResponse::ParseError(std::string message) {
34   DispatchResponse result;
35   result.code_ = DispatchCode::PARSE_ERROR;
36   result.message_ = std::move(message);
37   return result;
38 }
39 
40 // static
InvalidRequest(std::string message)41 DispatchResponse DispatchResponse::InvalidRequest(std::string message) {
42   DispatchResponse result;
43   result.code_ = DispatchCode::INVALID_REQUEST;
44   result.message_ = std::move(message);
45   return result;
46 }
47 
48 // static
MethodNotFound(std::string message)49 DispatchResponse DispatchResponse::MethodNotFound(std::string message) {
50   DispatchResponse result;
51   result.code_ = DispatchCode::METHOD_NOT_FOUND;
52   result.message_ = std::move(message);
53   return result;
54 }
55 
56 // static
InvalidParams(std::string message)57 DispatchResponse DispatchResponse::InvalidParams(std::string message) {
58   DispatchResponse result;
59   result.code_ = DispatchCode::INVALID_PARAMS;
60   result.message_ = std::move(message);
61   return result;
62 }
63 
64 // static
InternalError()65 DispatchResponse DispatchResponse::InternalError() {
66   DispatchResponse result;
67   result.code_ = DispatchCode::INTERNAL_ERROR;
68   result.message_ = "Internal error";
69   return result;
70 }
71 
72 // static
ServerError(std::string message)73 DispatchResponse DispatchResponse::ServerError(std::string message) {
74   DispatchResponse result;
75   result.code_ = DispatchCode::SERVER_ERROR;
76   result.message_ = std::move(message);
77   return result;
78 }
79 
80 // =============================================================================
81 // Dispatchable - a shallow parser for CBOR encoded DevTools messages
82 // =============================================================================
namespace {
// Number of bytes the Dispatchable constructor adds, together with the size
// of the envelope contents, to the envelope's start position in order to
// compute the position just past the envelope.
// NOTE(review): presumably two single-byte markers followed by a 32-bit
// length field - confirm against the envelope encoding in cbor.h.
constexpr size_t kEncodedEnvelopeHeaderSize =
    2 * sizeof(uint8_t) + sizeof(uint32_t);
}  // namespace
86 
Dispatchable(span<uint8_t> serialized)87 Dispatchable::Dispatchable(span<uint8_t> serialized) : serialized_(serialized) {
88   Status s = cbor::CheckCBORMessage(serialized);
89   if (!s.ok()) {
90     status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, s.pos};
91     return;
92   }
93   cbor::CBORTokenizer tokenizer(serialized);
94   if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
95     status_ = tokenizer.Status();
96     return;
97   }
98 
99   // We checked for the envelope start byte above, so the tokenizer
100   // must agree here, since it's not an error.
101   assert(tokenizer.TokenTag() == cbor::CBORTokenTag::ENVELOPE);
102 
103   // Before we enter the envelope, we save the position that we
104   // expect to see after we're done parsing the envelope contents.
105   // This way we can compare and produce an error if the contents
106   // didn't fit exactly into the envelope length.
107   const size_t pos_past_envelope = tokenizer.Status().pos +
108                                    kEncodedEnvelopeHeaderSize +
109                                    tokenizer.GetEnvelopeContents().size();
110   tokenizer.EnterEnvelope();
111   if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
112     status_ = tokenizer.Status();
113     return;
114   }
115   if (tokenizer.TokenTag() != cbor::CBORTokenTag::MAP_START) {
116     status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, tokenizer.Status().pos};
117     return;
118   }
119   assert(tokenizer.TokenTag() == cbor::CBORTokenTag::MAP_START);
120   tokenizer.Next();  // Now we should be pointed at the map key.
121   while (tokenizer.TokenTag() != cbor::CBORTokenTag::STOP) {
122     switch (tokenizer.TokenTag()) {
123       case cbor::CBORTokenTag::DONE:
124         status_ =
125             Status{Error::CBOR_UNEXPECTED_EOF_IN_MAP, tokenizer.Status().pos};
126         return;
127       case cbor::CBORTokenTag::ERROR_VALUE:
128         status_ = tokenizer.Status();
129         return;
130       case cbor::CBORTokenTag::STRING8:
131         if (!MaybeParseProperty(&tokenizer))
132           return;
133         break;
134       default:
135         // We require the top-level keys to be UTF8 (US-ASCII in practice).
136         status_ = Status{Error::CBOR_INVALID_MAP_KEY, tokenizer.Status().pos};
137         return;
138     }
139   }
140   tokenizer.Next();
141   if (!has_call_id_) {
142     status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
143                      tokenizer.Status().pos};
144     return;
145   }
146   if (method_.empty()) {
147     status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
148                      tokenizer.Status().pos};
149     return;
150   }
151   // The contents of the envelope parsed OK, now check that we're at
152   // the expected position.
153   if (pos_past_envelope != tokenizer.Status().pos) {
154     status_ = Status{Error::CBOR_ENVELOPE_CONTENTS_LENGTH_MISMATCH,
155                      tokenizer.Status().pos};
156     return;
157   }
158   if (tokenizer.TokenTag() != cbor::CBORTokenTag::DONE) {
159     status_ = Status{Error::CBOR_TRAILING_JUNK, tokenizer.Status().pos};
160     return;
161   }
162 }
163 
ok() const164 bool Dispatchable::ok() const {
165   return status_.ok();
166 }
167 
DispatchError() const168 DispatchResponse Dispatchable::DispatchError() const {
169   // TODO(johannes): Replace with DCHECK / similar?
170   if (status_.ok())
171     return DispatchResponse::Success();
172 
173   if (status_.IsMessageError())
174     return DispatchResponse::InvalidRequest(status_.Message());
175   return DispatchResponse::ParseError(status_.ToASCIIString());
176 }
177 
MaybeParseProperty(cbor::CBORTokenizer * tokenizer)178 bool Dispatchable::MaybeParseProperty(cbor::CBORTokenizer* tokenizer) {
179   span<uint8_t> property_name = tokenizer->GetString8();
180   if (SpanEquals(SpanFrom("id"), property_name))
181     return MaybeParseCallId(tokenizer);
182   if (SpanEquals(SpanFrom("method"), property_name))
183     return MaybeParseMethod(tokenizer);
184   if (SpanEquals(SpanFrom("params"), property_name))
185     return MaybeParseParams(tokenizer);
186   if (SpanEquals(SpanFrom("sessionId"), property_name))
187     return MaybeParseSessionId(tokenizer);
188   status_ =
189       Status{Error::MESSAGE_HAS_UNKNOWN_PROPERTY, tokenizer->Status().pos};
190   return false;
191 }
192 
MaybeParseCallId(cbor::CBORTokenizer * tokenizer)193 bool Dispatchable::MaybeParseCallId(cbor::CBORTokenizer* tokenizer) {
194   if (has_call_id_) {
195     status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
196     return false;
197   }
198   tokenizer->Next();
199   if (tokenizer->TokenTag() != cbor::CBORTokenTag::INT32) {
200     status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
201                      tokenizer->Status().pos};
202     return false;
203   }
204   call_id_ = tokenizer->GetInt32();
205   has_call_id_ = true;
206   tokenizer->Next();
207   return true;
208 }
209 
MaybeParseMethod(cbor::CBORTokenizer * tokenizer)210 bool Dispatchable::MaybeParseMethod(cbor::CBORTokenizer* tokenizer) {
211   if (!method_.empty()) {
212     status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
213     return false;
214   }
215   tokenizer->Next();
216   if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
217     status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
218                      tokenizer->Status().pos};
219     return false;
220   }
221   method_ = tokenizer->GetString8();
222   tokenizer->Next();
223   return true;
224 }
225 
MaybeParseParams(cbor::CBORTokenizer * tokenizer)226 bool Dispatchable::MaybeParseParams(cbor::CBORTokenizer* tokenizer) {
227   if (params_seen_) {
228     status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
229     return false;
230   }
231   params_seen_ = true;
232   tokenizer->Next();
233   if (tokenizer->TokenTag() == cbor::CBORTokenTag::NULL_VALUE) {
234     tokenizer->Next();
235     return true;
236   }
237   if (tokenizer->TokenTag() != cbor::CBORTokenTag::ENVELOPE) {
238     status_ = Status{Error::MESSAGE_MAY_HAVE_OBJECT_PARAMS_PROPERTY,
239                      tokenizer->Status().pos};
240     return false;
241   }
242   params_ = tokenizer->GetEnvelope();
243   tokenizer->Next();
244   return true;
245 }
246 
MaybeParseSessionId(cbor::CBORTokenizer * tokenizer)247 bool Dispatchable::MaybeParseSessionId(cbor::CBORTokenizer* tokenizer) {
248   if (!session_id_.empty()) {
249     status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
250     return false;
251   }
252   tokenizer->Next();
253   if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
254     status_ = Status{Error::MESSAGE_MAY_HAVE_STRING_SESSION_ID_PROPERTY,
255                      tokenizer->Status().pos};
256     return false;
257   }
258   session_id_ = tokenizer->GetString8();
259   tokenizer->Next();
260   return true;
261 }
262 
263 namespace {
264 class ProtocolError : public Serializable {
265  public:
ProtocolError(DispatchResponse dispatch_response)266   explicit ProtocolError(DispatchResponse dispatch_response)
267       : dispatch_response_(std::move(dispatch_response)) {}
268 
AppendSerialized(std::vector<uint8_t> * out) const269   void AppendSerialized(std::vector<uint8_t>* out) const override {
270     Status status;
271     std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
272     encoder->HandleMapBegin();
273     if (has_call_id_) {
274       encoder->HandleString8(SpanFrom("id"));
275       encoder->HandleInt32(call_id_);
276     }
277     encoder->HandleString8(SpanFrom("error"));
278     encoder->HandleMapBegin();
279     encoder->HandleString8(SpanFrom("code"));
280     encoder->HandleInt32(static_cast<int32_t>(dispatch_response_.Code()));
281     encoder->HandleString8(SpanFrom("message"));
282     encoder->HandleString8(SpanFrom(dispatch_response_.Message()));
283     if (!data_.empty()) {
284       encoder->HandleString8(SpanFrom("data"));
285       encoder->HandleString8(SpanFrom(data_));
286     }
287     encoder->HandleMapEnd();
288     encoder->HandleMapEnd();
289     assert(status.ok());
290   }
291 
SetCallId(int call_id)292   void SetCallId(int call_id) {
293     has_call_id_ = true;
294     call_id_ = call_id;
295   }
SetData(std::string data)296   void SetData(std::string data) { data_ = std::move(data); }
297 
298  private:
299   const DispatchResponse dispatch_response_;
300   std::string data_;
301   int call_id_ = 0;
302   bool has_call_id_ = false;
303 };
304 }  // namespace
305 
306 // =============================================================================
// Helpers for creating protocol responses and notifications.
308 // =============================================================================
309 
CreateErrorResponse(int call_id,DispatchResponse dispatch_response,const ErrorSupport * errors)310 std::unique_ptr<Serializable> CreateErrorResponse(
311     int call_id,
312     DispatchResponse dispatch_response,
313     const ErrorSupport* errors) {
314   auto protocol_error =
315       std::make_unique<ProtocolError>(std::move(dispatch_response));
316   protocol_error->SetCallId(call_id);
317   if (errors && !errors->Errors().empty()) {
318     protocol_error->SetData(
319         std::string(errors->Errors().begin(), errors->Errors().end()));
320   }
321   return protocol_error;
322 }
323 
CreateErrorNotification(DispatchResponse dispatch_response)324 std::unique_ptr<Serializable> CreateErrorNotification(
325     DispatchResponse dispatch_response) {
326   return std::make_unique<ProtocolError>(std::move(dispatch_response));
327 }
328 
329 namespace {
330 class Response : public Serializable {
331  public:
Response(int call_id,std::unique_ptr<Serializable> params)332   Response(int call_id, std::unique_ptr<Serializable> params)
333       : call_id_(call_id), params_(std::move(params)) {}
334 
AppendSerialized(std::vector<uint8_t> * out) const335   void AppendSerialized(std::vector<uint8_t>* out) const override {
336     Status status;
337     std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
338     encoder->HandleMapBegin();
339     encoder->HandleString8(SpanFrom("id"));
340     encoder->HandleInt32(call_id_);
341     encoder->HandleString8(SpanFrom("result"));
342     if (params_) {
343       params_->AppendSerialized(out);
344     } else {
345       encoder->HandleMapBegin();
346       encoder->HandleMapEnd();
347     }
348     encoder->HandleMapEnd();
349     assert(status.ok());
350   }
351 
352  private:
353   const int call_id_;
354   std::unique_ptr<Serializable> params_;
355 };
356 
357 class Notification : public Serializable {
358  public:
Notification(const char * method,std::unique_ptr<Serializable> params)359   Notification(const char* method, std::unique_ptr<Serializable> params)
360       : method_(method), params_(std::move(params)) {}
361 
AppendSerialized(std::vector<uint8_t> * out) const362   void AppendSerialized(std::vector<uint8_t>* out) const override {
363     Status status;
364     std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
365     encoder->HandleMapBegin();
366     encoder->HandleString8(SpanFrom("method"));
367     encoder->HandleString8(SpanFrom(method_));
368     encoder->HandleString8(SpanFrom("params"));
369     if (params_) {
370       params_->AppendSerialized(out);
371     } else {
372       encoder->HandleMapBegin();
373       encoder->HandleMapEnd();
374     }
375     encoder->HandleMapEnd();
376     assert(status.ok());
377   }
378 
379  private:
380   const char* method_;
381   std::unique_ptr<Serializable> params_;
382 };
383 }  // namespace
384 
CreateResponse(int call_id,std::unique_ptr<Serializable> params)385 std::unique_ptr<Serializable> CreateResponse(
386     int call_id,
387     std::unique_ptr<Serializable> params) {
388   return std::make_unique<Response>(call_id, std::move(params));
389 }
390 
CreateNotification(const char * method,std::unique_ptr<Serializable> params)391 std::unique_ptr<Serializable> CreateNotification(
392     const char* method,
393     std::unique_ptr<Serializable> params) {
394   return std::make_unique<Notification>(method, std::move(params));
395 }
396 
397 // =============================================================================
// DomainDispatcher - Dispatching between protocol methods within a domain.
399 // =============================================================================
WeakPtr(DomainDispatcher * dispatcher)400 DomainDispatcher::WeakPtr::WeakPtr(DomainDispatcher* dispatcher)
401     : dispatcher_(dispatcher) {}
402 
~WeakPtr()403 DomainDispatcher::WeakPtr::~WeakPtr() {
404   if (dispatcher_)
405     dispatcher_->weak_ptrs_.erase(this);
406 }
407 
408 DomainDispatcher::Callback::~Callback() = default;
409 
dispose()410 void DomainDispatcher::Callback::dispose() {
411   backend_impl_ = nullptr;
412 }
413 
Callback(std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,int call_id,span<uint8_t> method,span<uint8_t> message)414 DomainDispatcher::Callback::Callback(
415     std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,
416     int call_id,
417     span<uint8_t> method,
418     span<uint8_t> message)
419     : backend_impl_(std::move(backend_impl)),
420       call_id_(call_id),
421       method_(method),
422       message_(message.begin(), message.end()) {}
423 
sendIfActive(std::unique_ptr<Serializable> partialMessage,const DispatchResponse & response)424 void DomainDispatcher::Callback::sendIfActive(
425     std::unique_ptr<Serializable> partialMessage,
426     const DispatchResponse& response) {
427   if (!backend_impl_ || !backend_impl_->get())
428     return;
429   backend_impl_->get()->sendResponse(call_id_, response,
430                                      std::move(partialMessage));
431   backend_impl_ = nullptr;
432 }
433 
fallThroughIfActive()434 void DomainDispatcher::Callback::fallThroughIfActive() {
435   if (!backend_impl_ || !backend_impl_->get())
436     return;
437   backend_impl_->get()->channel()->FallThrough(call_id_, method_,
438                                                SpanFrom(message_));
439   backend_impl_ = nullptr;
440 }
441 
DomainDispatcher(FrontendChannel * frontendChannel)442 DomainDispatcher::DomainDispatcher(FrontendChannel* frontendChannel)
443     : frontend_channel_(frontendChannel) {}
444 
~DomainDispatcher()445 DomainDispatcher::~DomainDispatcher() {
446   clearFrontend();
447 }
448 
sendResponse(int call_id,const DispatchResponse & response,std::unique_ptr<Serializable> result)449 void DomainDispatcher::sendResponse(int call_id,
450                                     const DispatchResponse& response,
451                                     std::unique_ptr<Serializable> result) {
452   if (!frontend_channel_)
453     return;
454   std::unique_ptr<Serializable> serializable;
455   if (response.IsError()) {
456     serializable = CreateErrorResponse(call_id, response);
457   } else {
458     serializable = CreateResponse(call_id, std::move(result));
459   }
460   frontend_channel_->SendProtocolResponse(call_id, std::move(serializable));
461 }
462 
MaybeReportInvalidParams(const Dispatchable & dispatchable,const ErrorSupport & errors)463 bool DomainDispatcher::MaybeReportInvalidParams(
464     const Dispatchable& dispatchable,
465     const ErrorSupport& errors) {
466   if (errors.Errors().empty())
467     return false;
468   if (frontend_channel_) {
469     frontend_channel_->SendProtocolResponse(
470         dispatchable.CallId(),
471         CreateErrorResponse(
472             dispatchable.CallId(),
473             DispatchResponse::InvalidParams("Invalid parameters"), &errors));
474   }
475   return true;
476 }
477 
clearFrontend()478 void DomainDispatcher::clearFrontend() {
479   frontend_channel_ = nullptr;
480   for (auto& weak : weak_ptrs_)
481     weak->dispose();
482   weak_ptrs_.clear();
483 }
484 
weakPtr()485 std::unique_ptr<DomainDispatcher::WeakPtr> DomainDispatcher::weakPtr() {
486   auto weak = std::make_unique<DomainDispatcher::WeakPtr>(this);
487   weak_ptrs_.insert(weak.get());
488   return weak;
489 }
490 
491 // =============================================================================
492 // UberDispatcher - dispatches between domains (backends).
493 // =============================================================================
DispatchResult(bool method_found,std::function<void ()> runnable)494 UberDispatcher::DispatchResult::DispatchResult(bool method_found,
495                                                std::function<void()> runnable)
496     : method_found_(method_found), runnable_(runnable) {}
497 
Run()498 void UberDispatcher::DispatchResult::Run() {
499   if (!runnable_)
500     return;
501   runnable_();
502   runnable_ = nullptr;
503 }
504 
UberDispatcher(FrontendChannel * frontend_channel)505 UberDispatcher::UberDispatcher(FrontendChannel* frontend_channel)
506     : frontend_channel_(frontend_channel) {
507   assert(frontend_channel);
508 }
509 
510 UberDispatcher::~UberDispatcher() = default;
511 
512 constexpr size_t kNotFound = std::numeric_limits<size_t>::max();
513 
514 namespace {
DotIdx(span<uint8_t> method)515 size_t DotIdx(span<uint8_t> method) {
516   const void* p = memchr(method.data(), '.', method.size());
517   return p ? reinterpret_cast<const uint8_t*>(p) - method.data() : kNotFound;
518 }
519 }  // namespace
520 
Dispatch(const Dispatchable & dispatchable) const521 UberDispatcher::DispatchResult UberDispatcher::Dispatch(
522     const Dispatchable& dispatchable) const {
523   span<uint8_t> method = FindByFirst(redirects_, dispatchable.Method(),
524                                      /*default_value=*/dispatchable.Method());
525   size_t dot_idx = DotIdx(method);
526   if (dot_idx != kNotFound) {
527     span<uint8_t> domain = method.subspan(0, dot_idx);
528     span<uint8_t> command = method.subspan(dot_idx + 1);
529     DomainDispatcher* dispatcher = FindByFirst(dispatchers_, domain);
530     if (dispatcher) {
531       std::function<void(const Dispatchable&)> dispatched =
532           dispatcher->Dispatch(command);
533       if (dispatched) {
534         return DispatchResult(
535             true, [dispatchable, dispatched = std::move(dispatched)]() {
536               dispatched(dispatchable);
537             });
538       }
539     }
540   }
541   return DispatchResult(false, [this, dispatchable]() {
542     frontend_channel_->SendProtocolResponse(
543         dispatchable.CallId(),
544         CreateErrorResponse(dispatchable.CallId(),
545                             DispatchResponse::MethodNotFound(
546                                 "'" +
547                                 std::string(dispatchable.Method().begin(),
548                                             dispatchable.Method().end()) +
549                                 "' wasn't found")));
550   });
551 }
552 
553 template <typename T>
554 struct FirstLessThan {
operator ()crdtp::FirstLessThan555   bool operator()(const std::pair<span<uint8_t>, T>& left,
556                   const std::pair<span<uint8_t>, T>& right) {
557     return SpanLessThan(left.first, right.first);
558   }
559 };
560 
WireBackend(span<uint8_t> domain,const std::vector<std::pair<span<uint8_t>,span<uint8_t>>> & sorted_redirects,std::unique_ptr<DomainDispatcher> dispatcher)561 void UberDispatcher::WireBackend(
562     span<uint8_t> domain,
563     const std::vector<std::pair<span<uint8_t>, span<uint8_t>>>&
564         sorted_redirects,
565     std::unique_ptr<DomainDispatcher> dispatcher) {
566   auto it = redirects_.insert(redirects_.end(), sorted_redirects.begin(),
567                               sorted_redirects.end());
568   std::inplace_merge(redirects_.begin(), it, redirects_.end(),
569                      FirstLessThan<span<uint8_t>>());
570   auto jt = dispatchers_.insert(dispatchers_.end(),
571                                 std::make_pair(domain, std::move(dispatcher)));
572   std::inplace_merge(dispatchers_.begin(), jt, dispatchers_.end(),
573                      FirstLessThan<std::unique_ptr<DomainDispatcher>>());
574 }
575 
576 }  // namespace crdtp
577