1 //===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utilities to support construction of simple RPC APIs.
10 //
11 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
12 // programmers, high performance, low memory overhead, and efficient use of the
13 // communications channel.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
18 #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19 
20 #include <map>
21 #include <thread>
22 #include <vector>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ExecutionEngine/Orc/OrcError.h"
26 #include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h"
27 #include "llvm/Support/MSVCErrorWorkarounds.h"
28 
29 #include <future>
30 
31 namespace llvm {
32 namespace orc {
33 namespace rpc {
34 
35 /// Base class of all fatal RPC errors (those that necessarily result in the
36 /// termination of the RPC session).
37 class RPCFatalError : public ErrorInfo<RPCFatalError> {
38 public:
39   static char ID;
40 };
41 
42 /// RPCConnectionClosed is returned from RPC operations if the RPC connection
43 /// has already been closed due to either an error or graceful disconnection.
44 class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45 public:
46   static char ID;
47   std::error_code convertToErrorCode() const override;
48   void log(raw_ostream &OS) const override;
49 };
50 
51 /// BadFunctionCall is returned from handleOne when the remote makes a call with
52 /// an unrecognized function id.
53 ///
54 /// This error is fatal because Orc RPC needs to know how to parse a function
55 /// call to know where the next call starts, and if it doesn't recognize the
56 /// function id it cannot parse the call.
57 template <typename FnIdT, typename SeqNoT>
58 class BadFunctionCall
59   : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60 public:
61   static char ID;
62 
BadFunctionCall(FnIdT FnId,SeqNoT SeqNo)63   BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64       : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65 
convertToErrorCode()66   std::error_code convertToErrorCode() const override {
67     return orcError(OrcErrorCode::UnexpectedRPCCall);
68   }
69 
log(raw_ostream & OS)70   void log(raw_ostream &OS) const override {
71     OS << "Call to invalid RPC function id '" << FnId << "' with "
72           "sequence number " << SeqNo;
73   }
74 
75 private:
76   FnIdT FnId;
77   SeqNoT SeqNo;
78 };
79 
80 template <typename FnIdT, typename SeqNoT>
81 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82 
83 /// InvalidSequenceNumberForResponse is returned from handleOne when a response
84 /// call arrives with a sequence number that doesn't correspond to any in-flight
85 /// function call.
86 ///
87 /// This error is fatal because Orc RPC needs to know how to parse the rest of
88 /// the response call to know where the next call starts, and if it doesn't have
89 /// a result parser for this sequence number it can't do that.
90 template <typename SeqNoT>
91 class InvalidSequenceNumberForResponse
92     : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93 public:
94   static char ID;
95 
InvalidSequenceNumberForResponse(SeqNoT SeqNo)96   InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97       : SeqNo(std::move(SeqNo)) {}
98 
convertToErrorCode()99   std::error_code convertToErrorCode() const override {
100     return orcError(OrcErrorCode::UnexpectedRPCCall);
101   };
102 
log(raw_ostream & OS)103   void log(raw_ostream &OS) const override {
104     OS << "Response has unknown sequence number " << SeqNo;
105   }
106 private:
107   SeqNoT SeqNo;
108 };
109 
110 template <typename SeqNoT>
111 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112 
113 /// This non-fatal error will be passed to asynchronous result handlers in place
114 /// of a result if the connection goes down before a result returns, or if the
115 /// function to be called cannot be negotiated with the remote.
116 class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
117 public:
118   static char ID;
119 
120   std::error_code convertToErrorCode() const override;
121   void log(raw_ostream &OS) const override;
122 };
123 
124 /// This error is returned if the remote does not have a handler installed for
125 /// the given RPC function.
126 class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
127 public:
128   static char ID;
129 
130   CouldNotNegotiate(std::string Signature);
131   std::error_code convertToErrorCode() const override;
132   void log(raw_ostream &OS) const override;
getSignature()133   const std::string &getSignature() const { return Signature; }
134 private:
135   std::string Signature;
136 };
137 
138 template <typename DerivedFunc, typename FnT> class Function;
139 
140 // RPC Function class.
141 // DerivedFunc should be a user defined class with a static 'getName()' method
142 // returning a const char* representing the function's name.
143 template <typename DerivedFunc, typename RetT, typename... ArgTs>
144 class Function<DerivedFunc, RetT(ArgTs...)> {
145 public:
146   /// User defined function type.
147   using Type = RetT(ArgTs...);
148 
149   /// Return type.
150   using ReturnType = RetT;
151 
152   /// Returns the full function prototype as a string.
getPrototype()153   static const char *getPrototype() {
154     static std::string Name = [] {
155       std::string Name;
156       raw_string_ostream(Name)
157           << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158           << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159       return Name;
160     }();
161     return Name.data();
162   }
163 };
164 
/// Hands out RPC function ids during autonegotiation.
///
/// Every specialization must provide:
///   static T getInvalidId()   - reserved id standing for a missing function.
///   static T getResponseId()  - reserved id used to send function responses
///                               (return values).
///   static T getNegotiateId() - reserved id of the negotiate function, which
///                               negotiates ids for user-defined functions.
///   template <typename Func> T allocate()
///                             - returns a fresh, unique id for Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default allocator for integral id types: ids 0, 1 and 2 are reserved, and
/// allocate() counts upward from 3.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id not claimed by the reserved accessors above.
  T NextId = 3;
};
199 
200 namespace detail {
201 
/// Exposes `Type`: a std::tuple of the function's parameter types with
/// references and cv-qualifiers removed (std::decay), suitable for holding
/// deserialized argument values.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already strips references, so no separate remove_reference
  // step is needed.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
211 
212 // ResultTraits provides typedefs and utilities specific to the return type
213 // of functions.
214 template <typename RetT> class ResultTraits {
215 public:
216   // The return type wrapped in llvm::Expected.
217   using ErrorReturnType = Expected<RetT>;
218 
219 #ifdef _MSC_VER
220   // The ErrorReturnType wrapped in a std::promise.
221   using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
222 
223   // The ErrorReturnType wrapped in a std::future.
224   using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
225 #else
226   // The ErrorReturnType wrapped in a std::promise.
227   using ReturnPromiseType = std::promise<ErrorReturnType>;
228 
229   // The ErrorReturnType wrapped in a std::future.
230   using ReturnFutureType = std::future<ErrorReturnType>;
231 #endif
232 
233   // Create a 'blank' value of the ErrorReturnType, ready and safe to
234   // overwrite.
createBlankErrorReturnValue()235   static ErrorReturnType createBlankErrorReturnValue() {
236     return ErrorReturnType(RetT());
237   }
238 
239   // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType RetOrErr)240   static void consumeAbandoned(ErrorReturnType RetOrErr) {
241     consumeError(RetOrErr.takeError());
242   }
243 };
244 
245 // ResultTraits specialization for void functions.
246 template <> class ResultTraits<void> {
247 public:
248   // For void functions, ErrorReturnType is llvm::Error.
249   using ErrorReturnType = Error;
250 
251 #ifdef _MSC_VER
252   // The ErrorReturnType wrapped in a std::promise.
253   using ReturnPromiseType = std::promise<MSVCPError>;
254 
255   // The ErrorReturnType wrapped in a std::future.
256   using ReturnFutureType = std::future<MSVCPError>;
257 #else
258   // The ErrorReturnType wrapped in a std::promise.
259   using ReturnPromiseType = std::promise<ErrorReturnType>;
260 
261   // The ErrorReturnType wrapped in a std::future.
262   using ReturnFutureType = std::future<ErrorReturnType>;
263 #endif
264 
265   // Create a 'blank' value of the ErrorReturnType, ready and safe to
266   // overwrite.
createBlankErrorReturnValue()267   static ErrorReturnType createBlankErrorReturnValue() {
268     return ErrorReturnType::success();
269   }
270 
271   // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType Err)272   static void consumeAbandoned(ErrorReturnType Err) {
273     consumeError(std::move(Err));
274   }
275 };
276 
277 // ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
278 // handlers for void RPC functions to return either void (in which case they
279 // implicitly succeed) or Error (in which case their error return is
280 // propagated). See usage in HandlerTraits::runHandlerHelper.
281 template <> class ResultTraits<Error> : public ResultTraits<void> {};
282 
283 // ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
284 // handlers for RPC functions returning a T to return either a T (in which
285 // case they implicitly succeed) or Expected<T> (in which case their error
286 // return is propagated). See usage in HandlerTraits::runHandlerHelper.
287 template <typename RetT>
288 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
289 
290 // Determines whether an RPC function's defined error return type supports
291 // error return value.
292 template <typename T>
293 class SupportsErrorReturn {
294 public:
295   static const bool value = false;
296 };
297 
298 template <>
299 class SupportsErrorReturn<Error> {
300 public:
301   static const bool value = true;
302 };
303 
304 template <typename T>
305 class SupportsErrorReturn<Expected<T>> {
306 public:
307   static const bool value = true;
308 };
309 
310 // RespondHelper packages return values based on whether or not the declared
311 // RPC function return type supports error returns.
312 template <bool FuncSupportsErrorReturn>
313 class RespondHelper;
314 
315 // RespondHelper specialization for functions that support error returns.
316 template <>
317 class RespondHelper<true> {
318 public:
319 
320   // Send Expected<T>.
321   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
322             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)323   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
324                           SequenceNumberT SeqNo,
325                           Expected<HandlerRetT> ResultOrErr) {
326     if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
327       return ResultOrErr.takeError();
328 
329     // Open the response message.
330     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
331       return Err;
332 
333     // Serialize the result.
334     if (auto Err =
335         SerializationTraits<ChannelT, WireRetT,
336                             Expected<HandlerRetT>>::serialize(
337                                                      C, std::move(ResultOrErr)))
338       return Err;
339 
340     // Close the response message.
341     if (auto Err = C.endSendMessage())
342       return Err;
343     return C.send();
344   }
345 
346   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)347   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
348                           SequenceNumberT SeqNo, Error Err) {
349     if (Err && Err.isA<RPCFatalError>())
350       return Err;
351     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
352       return Err2;
353     if (auto Err2 = serializeSeq(C, std::move(Err)))
354       return Err2;
355     if (auto Err2 = C.endSendMessage())
356       return Err2;
357     return C.send();
358   }
359 
360 };
361 
362 // RespondHelper specialization for functions that do not support error returns.
363 template <>
364 class RespondHelper<false> {
365 public:
366 
367   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
368             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)369   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
370                           SequenceNumberT SeqNo,
371                           Expected<HandlerRetT> ResultOrErr) {
372     if (auto Err = ResultOrErr.takeError())
373       return Err;
374 
375     // Open the response message.
376     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
377       return Err;
378 
379     // Serialize the result.
380     if (auto Err =
381         SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
382                                                                C, *ResultOrErr))
383       return Err;
384 
385     // End the response message.
386     if (auto Err = C.endSendMessage())
387       return Err;
388 
389     return C.send();
390   }
391 
392   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)393   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
394                           SequenceNumberT SeqNo, Error Err) {
395     if (Err)
396       return Err;
397     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
398       return Err2;
399     if (auto Err2 = C.endSendMessage())
400       return Err2;
401     return C.send();
402   }
403 
404 };
405 
406 
407 // Send a response of the given wire return type (WireRetT) over the
408 // channel, with the given sequence number.
409 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
410           typename FunctionIdT, typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)411 Error respond(ChannelT &C, const FunctionIdT &ResponseId,
412               SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
413   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
414     template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
415 }
416 
417 // Send an empty response message on the given channel to indicate that
418 // the handler ran.
419 template <typename WireRetT, typename ChannelT, typename FunctionIdT,
420           typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)421 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
422               Error Err) {
423   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
424     sendResult(C, ResponseId, SeqNo, std::move(Err));
425 }
426 
427 // Converts a given type to the equivalent error return type.
428 template <typename T> class WrappedHandlerReturn {
429 public:
430   using Type = Expected<T>;
431 };
432 
433 template <typename T> class WrappedHandlerReturn<Expected<T>> {
434 public:
435   using Type = Expected<T>;
436 };
437 
438 template <> class WrappedHandlerReturn<void> {
439 public:
440   using Type = Error;
441 };
442 
443 template <> class WrappedHandlerReturn<Error> {
444 public:
445   using Type = Error;
446 };
447 
448 template <> class WrappedHandlerReturn<ErrorSuccess> {
449 public:
450   using Type = Error;
451 };
452 
453 // Traits class that strips the response function from the list of handler
454 // arguments.
455 template <typename FnT> class AsyncHandlerTraits;
456 
457 template <typename ResultT, typename... ArgTs>
458 class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
459 public:
460   using Type = Error(ArgTs...);
461   using ResultType = Expected<ResultT>;
462 };
463 
464 template <typename... ArgTs>
465 class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
466 public:
467   using Type = Error(ArgTs...);
468   using ResultType = Error;
469 };
470 
471 template <typename... ArgTs>
472 class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
473 public:
474   using Type = Error(ArgTs...);
475   using ResultType = Error;
476 };
477 
478 template <typename... ArgTs>
479 class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
480 public:
481   using Type = Error(ArgTs...);
482   using ResultType = Error;
483 };
484 
485 template <typename ResponseHandlerT, typename... ArgTs>
486 class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
487     public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
488                                     ArgTs...)> {};
489 
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types such as lambdas and other
// functors (the template is specialized separately for function types,
// function pointers and member-function pointers). It inherits from the
// specialization that matches the type's call operator, after stripping any
// reference from HandlerT.
template <typename HandlerT>
class HandlerTraits : public HandlerTraits<decltype(
                          &std::remove_reference<HandlerT>::type::operator())> {
};
498 
499 // Traits for handlers with a given function type.
500 template <typename RetT, typename... ArgTs>
501 class HandlerTraits<RetT(ArgTs...)> {
502 public:
503   // Function type of the handler.
504   using Type = RetT(ArgTs...);
505 
506   // Return type of the handler.
507   using ReturnType = RetT;
508 
509   // Call the given handler with the given arguments.
510   template <typename HandlerT, typename... TArgTs>
511   static typename WrappedHandlerReturn<RetT>::Type
unpackAndRun(HandlerT & Handler,std::tuple<TArgTs...> & Args)512   unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
513     return unpackAndRunHelper(Handler, Args,
514                               std::index_sequence_for<TArgTs...>());
515   }
516 
517   // Call the given handler with the given arguments.
518   template <typename HandlerT, typename ResponderT, typename... TArgTs>
unpackAndRunAsync(HandlerT & Handler,ResponderT & Responder,std::tuple<TArgTs...> & Args)519   static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
520                                  std::tuple<TArgTs...> &Args) {
521     return unpackAndRunAsyncHelper(Handler, Responder, Args,
522                                    std::index_sequence_for<TArgTs...>());
523   }
524 
525   // Call the given handler with the given arguments.
526   template <typename HandlerT>
527   static typename std::enable_if<
528       std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
529       Error>::type
run(HandlerT & Handler,ArgTs &&...Args)530   run(HandlerT &Handler, ArgTs &&... Args) {
531     Handler(std::move(Args)...);
532     return Error::success();
533   }
534 
535   template <typename HandlerT, typename... TArgTs>
536   static typename std::enable_if<
537       !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
538       typename HandlerTraits<HandlerT>::ReturnType>::type
run(HandlerT & Handler,TArgTs...Args)539   run(HandlerT &Handler, TArgTs... Args) {
540     return Handler(std::move(Args)...);
541   }
542 
543   // Serialize arguments to the channel.
544   template <typename ChannelT, typename... CArgTs>
serializeArgs(ChannelT & C,const CArgTs...CArgs)545   static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
546     return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
547   }
548 
549   // Deserialize arguments from the channel.
550   template <typename ChannelT, typename... CArgTs>
deserializeArgs(ChannelT & C,std::tuple<CArgTs...> & Args)551   static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
552     return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
553   }
554 
555 private:
556   template <typename ChannelT, typename... CArgTs, size_t... Indexes>
deserializeArgsHelper(ChannelT & C,std::tuple<CArgTs...> & Args,std::index_sequence<Indexes...> _)557   static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
558                                      std::index_sequence<Indexes...> _) {
559     return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
560         C, std::get<Indexes>(Args)...);
561   }
562 
563   template <typename HandlerT, typename ArgTuple, size_t... Indexes>
564   static typename WrappedHandlerReturn<
565       typename HandlerTraits<HandlerT>::ReturnType>::Type
unpackAndRunHelper(HandlerT & Handler,ArgTuple & Args,std::index_sequence<Indexes...>)566   unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
567                      std::index_sequence<Indexes...>) {
568     return run(Handler, std::move(std::get<Indexes>(Args))...);
569   }
570 
571   template <typename HandlerT, typename ResponderT, typename ArgTuple,
572             size_t... Indexes>
573   static typename WrappedHandlerReturn<
574       typename HandlerTraits<HandlerT>::ReturnType>::Type
unpackAndRunAsyncHelper(HandlerT & Handler,ResponderT & Responder,ArgTuple & Args,std::index_sequence<Indexes...>)575   unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
576                           ArgTuple &Args, std::index_sequence<Indexes...>) {
577     return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
578   }
579 };
580 
581 // Handler traits for free functions.
582 template <typename RetT, typename... ArgTs>
583 class HandlerTraits<RetT(*)(ArgTs...)>
584   : public HandlerTraits<RetT(ArgTs...)> {};
585 
586 // Handler traits for class methods (especially call operators for lambdas).
587 template <typename Class, typename RetT, typename... ArgTs>
588 class HandlerTraits<RetT (Class::*)(ArgTs...)>
589     : public HandlerTraits<RetT(ArgTs...)> {};
590 
591 // Handler traits for const class methods (especially call operators for
592 // lambdas).
593 template <typename Class, typename RetT, typename... ArgTs>
594 class HandlerTraits<RetT (Class::*)(ArgTs...) const>
595     : public HandlerTraits<RetT(ArgTs...)> {};
596 
597 // Utility to peel the Expected wrapper off a response handler error type.
598 template <typename HandlerT> class ResponseHandlerArg;
599 
600 template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
601 public:
602   using ArgType = Expected<ArgT>;
603   using UnwrappedArgType = ArgT;
604 };
605 
606 template <typename ArgT>
607 class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
608 public:
609   using ArgType = Expected<ArgT>;
610   using UnwrappedArgType = ArgT;
611 };
612 
613 template <> class ResponseHandlerArg<Error(Error)> {
614 public:
615   using ArgType = Error;
616 };
617 
618 template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
619 public:
620   using ArgType = Error;
621 };
622 
623 // ResponseHandler represents a handler for a not-yet-received function call
624 // result.
625 template <typename ChannelT> class ResponseHandler {
626 public:
~ResponseHandler()627   virtual ~ResponseHandler() {}
628 
629   // Reads the function result off the wire and acts on it. The meaning of
630   // "act" will depend on how this method is implemented in any given
631   // ResponseHandler subclass but could, for example, mean running a
632   // user-specified handler or setting a promise value.
633   virtual Error handleResponse(ChannelT &C) = 0;
634 
635   // Abandons this outstanding result.
636   virtual void abandon() = 0;
637 
638   // Create an error instance representing an abandoned response.
createAbandonedResponseError()639   static Error createAbandonedResponseError() {
640     return make_error<ResponseAbandoned>();
641   }
642 };
643 
644 // ResponseHandler subclass for RPC functions with non-void returns.
645 template <typename ChannelT, typename FuncRetT, typename HandlerT>
646 class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
647 public:
ResponseHandlerImpl(HandlerT Handler)648   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
649 
650   // Handle the result by deserializing it from the channel then passing it
651   // to the user defined handler.
handleResponse(ChannelT & C)652   Error handleResponse(ChannelT &C) override {
653     using UnwrappedArgType = typename ResponseHandlerArg<
654         typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
655     UnwrappedArgType Result;
656     if (auto Err =
657             SerializationTraits<ChannelT, FuncRetT,
658                                 UnwrappedArgType>::deserialize(C, Result))
659       return Err;
660     if (auto Err = C.endReceiveMessage())
661       return Err;
662     return Handler(std::move(Result));
663   }
664 
665   // Abandon this response by calling the handler with an 'abandoned response'
666   // error.
abandon()667   void abandon() override {
668     if (auto Err = Handler(this->createAbandonedResponseError())) {
669       // Handlers should not fail when passed an abandoned response error.
670       report_fatal_error(std::move(Err));
671     }
672   }
673 
674 private:
675   HandlerT Handler;
676 };
677 
678 // ResponseHandler subclass for RPC functions with void returns.
679 template <typename ChannelT, typename HandlerT>
680 class ResponseHandlerImpl<ChannelT, void, HandlerT>
681     : public ResponseHandler<ChannelT> {
682 public:
ResponseHandlerImpl(HandlerT Handler)683   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
684 
685   // Handle the result (no actual value, just a notification that the function
686   // has completed on the remote end) by calling the user-defined handler with
687   // Error::success().
handleResponse(ChannelT & C)688   Error handleResponse(ChannelT &C) override {
689     if (auto Err = C.endReceiveMessage())
690       return Err;
691     return Handler(Error::success());
692   }
693 
694   // Abandon this response by calling the handler with an 'abandoned response'
695   // error.
abandon()696   void abandon() override {
697     if (auto Err = Handler(this->createAbandonedResponseError())) {
698       // Handlers should not fail when passed an abandoned response error.
699       report_fatal_error(std::move(Err));
700     }
701   }
702 
703 private:
704   HandlerT Handler;
705 };
706 
707 template <typename ChannelT, typename FuncRetT, typename HandlerT>
708 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
709     : public ResponseHandler<ChannelT> {
710 public:
ResponseHandlerImpl(HandlerT Handler)711   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
712 
713   // Handle the result by deserializing it from the channel then passing it
714   // to the user defined handler.
handleResponse(ChannelT & C)715   Error handleResponse(ChannelT &C) override {
716     using HandlerArgType = typename ResponseHandlerArg<
717         typename HandlerTraits<HandlerT>::Type>::ArgType;
718     HandlerArgType Result((typename HandlerArgType::value_type()));
719 
720     if (auto Err =
721             SerializationTraits<ChannelT, Expected<FuncRetT>,
722                                 HandlerArgType>::deserialize(C, Result))
723       return Err;
724     if (auto Err = C.endReceiveMessage())
725       return Err;
726     return Handler(std::move(Result));
727   }
728 
729   // Abandon this response by calling the handler with an 'abandoned response'
730   // error.
abandon()731   void abandon() override {
732     if (auto Err = Handler(this->createAbandonedResponseError())) {
733       // Handlers should not fail when passed an abandoned response error.
734       report_fatal_error(std::move(Err));
735     }
736   }
737 
738 private:
739   HandlerT Handler;
740 };
741 
742 template <typename ChannelT, typename HandlerT>
743 class ResponseHandlerImpl<ChannelT, Error, HandlerT>
744     : public ResponseHandler<ChannelT> {
745 public:
ResponseHandlerImpl(HandlerT Handler)746   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
747 
748   // Handle the result by deserializing it from the channel then passing it
749   // to the user defined handler.
handleResponse(ChannelT & C)750   Error handleResponse(ChannelT &C) override {
751     Error Result = Error::success();
752     if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
753             C, Result)) {
754       consumeError(std::move(Result));
755       return Err;
756     }
757     if (auto Err = C.endReceiveMessage()) {
758       consumeError(std::move(Result));
759       return Err;
760     }
761     return Handler(std::move(Result));
762   }
763 
764   // Abandon this response by calling the handler with an 'abandoned response'
765   // error.
abandon()766   void abandon() override {
767     if (auto Err = Handler(this->createAbandonedResponseError())) {
768       // Handlers should not fail when passed an abandoned response error.
769       report_fatal_error(std::move(Err));
770     }
771   }
772 
773 private:
774   HandlerT Handler;
775 };
776 
777 // Create a ResponseHandler from a given user handler.
778 template <typename ChannelT, typename FuncRetT, typename HandlerT>
createResponseHandler(HandlerT H)779 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
780   return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
781       std::move(H));
782 }
783 
// Functor that binds a member function to an object, e.g. so a method can be
// installed as a result handler. The object is held by reference and must
// outlive the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Obj, MethodT Fn) : Target(Obj), Fn(Fn) {}

  // Forward the (rvalue) arguments to the bound method.
  RetT operator()(ArgTs &&... Args) {
    return (Target.*Fn)(std::move(Args)...);
  }

private:
  ClassT &Target;
  MethodT Fn;
};
800 
801 // Helper that provides a Functor for deserializing arguments.
802 template <typename... ArgTs> class ReadArgs {
803 public:
operator()804   Error operator()() { return Error::success(); }
805 };
806 
807 template <typename ArgT, typename... ArgTs>
808 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
809 public:
ReadArgs(ArgT & Arg,ArgTs &...Args)810   ReadArgs(ArgT &Arg, ArgTs &... Args)
811       : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
812 
operator()813   Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
814     this->Arg = std::move(ArgVal);
815     return ReadArgs<ArgTs...>::operator()(ArgVals...);
816   }
817 
818 private:
819   ArgT &Arg;
820 };
821 
// Hands out sequence numbers for in-flight calls and recycles released ones.
// Released numbers are reused most-recently-released first, before any fresh
// number is minted. Every operation takes SeqNoLock.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Forget all allocations: the recycle list is emptied and numbering
  // restarts at zero.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Return the most recently released number if there is one; otherwise
  // return the next never-used number.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Put SequenceNumber back on the recycle list for a later
  // getSequenceNumber call.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
854 
// Compile-time conjunction over two type lists. value is true exactly when
// P<T, U>::value holds for every position-wise pair of element types taken
// from the std::tuple arguments T1Tuple and T2Tuple.
//
// Only equal-length tuples are handled; no specialization exists for
// mismatched lengths.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Base case: two empty tuples trivially satisfy the predicate.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Recursive case: check the leading pair, then the remaining elements.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using Rest = RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = P<T, U>::value && Rest::value;
};
874 
875 template <template <class, class> class P, typename T1Sig, typename T2Sig>
876 class RPCArgTypeCheck {
877 public:
878   using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
879   using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
880 
881   static_assert(std::tuple_size<T1Tuple>::value >=
882                     std::tuple_size<T2Tuple>::value,
883                 "Too many arguments to RPC call");
884   static_assert(std::tuple_size<T1Tuple>::value <=
885                     std::tuple_size<T2Tuple>::value,
886                 "Too few arguments to RPC call");
887 
888   static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
889 };
890 
891 template <typename ChannelT, typename WireT, typename ConcreteT>
892 class CanSerialize {
893 private:
894   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
895 
896   template <typename T>
897   static std::true_type
898   check(typename std::enable_if<
899         std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
900                                            std::declval<const ConcreteT &>())),
901                      Error>::value,
902         void *>::type);
903 
904   template <typename> static std::false_type check(...);
905 
906 public:
907   static const bool value = decltype(check<S>(0))::value;
908 };
909 
910 template <typename ChannelT, typename WireT, typename ConcreteT>
911 class CanDeserialize {
912 private:
913   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
914 
915   template <typename T>
916   static std::true_type
917   check(typename std::enable_if<
918         std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
919                                              std::declval<ConcreteT &>())),
920                      Error>::value,
921         void *>::type);
922 
923   template <typename> static std::false_type check(...);
924 
925 public:
926   static const bool value = decltype(check<S>(0))::value;
927 };
928 
/// Contains primitive utilities for defining, calling and handling calls to
/// remote procedures.
///
/// Template parameters:
/// - ChannelT is a bidirectional stream conforming to the RPCChannel
///   interface (see RPCChannel.h).
/// - FunctionIdT is a procedure identifier type that must be serializable on
///   ChannelT.
/// - SequenceNumberT is an integral type used to number in-flight function
///   calls.
/// - ImplT is the concrete endpoint class deriving from this base (CRTP). It
///   must provide callB, which getRemoteFunctionId uses to negotiate remote
///   function ids.
///
/// These are deliberately low-level building blocks for RPC APIs. Their
/// intent is to ensure correct serialization and deserialization of
/// procedure arguments, and to keep the client's and server's views of the
/// API in sync.
template <typename ImplT, typename ChannelT, typename FunctionIdT,
          typename SequenceNumberT>
class RPCEndpointBase {
protected:
  class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
  public:
    static const char *getName() { return "__orc_rpc$invalid"; }
  };

  class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
  public:
    static const char *getName() { return "__orc_rpc$response"; }
  };

  class OrcRPCNegotiate
      : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
  public:
    static const char *getName() { return "__orc_rpc$negotiate"; }
  };

  // Helper predicate for testing for the presence of SerializationTraits
  // serializers. Fails to compile (static_assert) when one is missing.
  template <typename WireT, typename ConcreteT>
  class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
  public:
    using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;

    static_assert(value, "Missing serializer for argument (Can't serialize the "
                         "first template type argument of CanSerializeCheck "
                         "from the second)");
  };

  // Helper predicate for testing for the presence of SerializationTraits
  // deserializers. Fails to compile (static_assert) when one is missing.
  template <typename WireT, typename ConcreteT>
  class CanDeserializeCheck
      : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
  public:
    using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;

    static_assert(value, "Missing deserializer for argument (Can't deserialize "
                         "the second template type argument of "
                         "CanDeserializeCheck from the first)");
  };

public:
  /// Construct an RPC instance on a channel.
  ///
  /// Pre-registers the reserved response and negotiate ids as remote ids,
  /// and installs the local handler for negotiate requests. If
  /// LazyAutoNegotiation is true, appendCallAsync negotiates an id for a
  /// function the first time it is called, when no id is known yet.
  RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
      : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
    // Hold ResponseId in a special variable, since we expect Response to be
    // called relatively frequently, and want to avoid the map lookup.
    ResponseId = FnIdAllocator.getResponseId();
    RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;

    // Register the negotiate function id and handler.
    auto NegotiateId = FnIdAllocator.getNegotiateId();
    RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
    Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
        [this](const std::string &Name) { return handleNegotiate(Name); });
  }


  /// Negotiate a function id for Func with the other end of the channel.
  ///
  /// If a valid id is already known, no negotiation takes place. If Retry is
  /// true, a previously negotiated id that came back invalid is negotiated
  /// again; otherwise that case returns an error.
  template <typename Func> Error negotiateFunction(bool Retry = false) {
    return getRemoteFunctionId<Func>(true, Retry).takeError();
  }

  /// Append a call Func, does not call send on the channel.
  /// The first argument specifies a user-defined handler to be run when the
  /// function returns. The handler should take an Expected<Func::ReturnType>,
  /// or an Error (if Func::ReturnType is void). The handler will be called
  /// with an error if the return value is abandoned due to a channel error.
  ///
  /// On every failure path the handler has been invoked with an abandoned
  /// error before this returns. Once the handler has been registered, a
  /// channel failure abandons *all* pending responses, not only this one.
  template <typename Func, typename HandlerT, typename... ArgTs>
  Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {

    static_assert(
        detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
                                void(ArgTs...)>::value,
        "");

    // Look up the function ID.
    FunctionIdT FnId;
    if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
      FnId = *FnIdOrErr;
    else {
      // Negotiation failed. Notify the handler then return the negotiate-failed
      // error. (cantFail: the handler must not itself fail here.)
      cantFail(Handler(make_error<ResponseAbandoned>()));
      return FnIdOrErr.takeError();
    }

    SequenceNumberT SeqNo; // initialized in locked scope below.
    {
      // Lock the pending responses map and sequence number manager.
      std::lock_guard<std::mutex> Lock(ResponsesMutex);

      // Allocate a sequence number.
      SeqNo = SequenceNumberMgr.getSequenceNumber();
      assert(!PendingResponses.count(SeqNo) &&
             "Sequence number already allocated");

      // Install the user handler.
      PendingResponses[SeqNo] =
        detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
            std::move(Handler));
    }

    // Open the function call message.
    if (auto Err = C.startSendMessage(FnId, SeqNo)) {
      abandonPendingResponses();
      return Err;
    }

    // Serialize the call arguments.
    if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
            C, Args...)) {
      abandonPendingResponses();
      return Err;
    }

    // Close the function call message.
    if (auto Err = C.endSendMessage()) {
      abandonPendingResponses();
      return Err;
    }

    return Error::success();
  }

  /// Flush calls queued by appendCallAsync by calling send() on the channel.
  Error sendAppendedCalls() { return C.send(); };

  /// Like appendCallAsync, but also calls send() on the channel if the call
  /// was appended successfully.
  template <typename Func, typename HandlerT, typename... ArgTs>
  Error callAsync(HandlerT Handler, const ArgTs &... Args) {
    if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
      return Err;
    return C.send();
  }

  /// Handle one incoming call.
  ///
  /// Reads one message header from the channel, then dispatches on its id:
  /// - the response id goes to handleResponse for the header's sequence
  ///   number;
  /// - any other id goes to the handler registered for it.
  /// A failure to read the header abandons all pending responses.
  Error handleOne() {
    FunctionIdT FnId;
    SequenceNumberT SeqNo;
    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
      abandonPendingResponses();
      return Err;
    }
    if (FnId == ResponseId)
      return handleResponse(SeqNo);
    auto I = Handlers.find(FnId);
    if (I != Handlers.end())
      return I->second(C, SeqNo);

    // else: No handler found. Report error to client?
    // NOTE(review): unlike the header-read failure above, pending responses
    // are not abandoned on this path (nor when a handler returns an error).
    return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
                                                                     SeqNo);
  }

  /// Helper for handling setter procedures - this method returns a functor that
  /// sets the variables referred to by Args... to values deserialized from the
  /// channel.
  /// E.g.
  ///
  ///   typedef Function<0, bool, int> Func1;
  ///
  ///   ...
  ///   bool B;
  ///   int I;
  ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
  ///     /* Handle Args */ ;
  ///
  /// NOTE(review): `expect` and the `Function<0, bool, int>` form used in
  /// this example do not appear in this part of the file; the example may
  /// predate the current API.
  template <typename... ArgTs>
  static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
    return detail::ReadArgs<ArgTs...>(Args...);
  }

  /// Abandon all outstanding result handlers.
  ///
  /// This will call all currently registered result handlers to receive an
  /// "abandoned" error as their argument. This is used internally by the RPC
  /// in error situations, but can also be called directly by clients who are
  /// disconnecting from the remote and don't or can't expect responses to their
  /// outstanding calls. (Especially for outstanding blocking calls, calling
  /// this function may be necessary to avoid dead threads).
  ///
  /// The handlers run while ResponsesMutex is held. A handler that calls back
  /// into appendCallAsync would therefore try to lock the same std::mutex
  /// again on the same thread.
  void abandonPendingResponses() {
    // Lock the pending responses map and sequence number manager.
    std::lock_guard<std::mutex> Lock(ResponsesMutex);

    for (auto &KV : PendingResponses)
      KV.second->abandon();
    PendingResponses.clear();
    SequenceNumberMgr.reset();
  }

  /// Remove the handler for the given function.
  /// A handler must currently be registered for this function.
  ///
  /// Only the handler is erased. The LocalFunctionIds entry remains, so
  /// handleNegotiate will still report the old id for Func.
  template <typename Func>
  void removeHandler() {
    auto IdItr = LocalFunctionIds.find(Func::getPrototype());
    assert(IdItr != LocalFunctionIds.end() &&
           "Function does not have a registered handler");
    auto HandlerItr = Handlers.find(IdItr->second);
    assert(HandlerItr != Handlers.end() &&
           "Function does not have a registered handler");
    Handlers.erase(HandlerItr);
  }

  /// Clear all handlers.
  ///
  /// This also removes the negotiate handler installed by the constructor,
  /// so incoming negotiate requests will no longer find a handler.
  void clearHandlers() {
    Handlers.clear();
  }

protected:

  /// The allocator's "no such function" id. handleNegotiate returns it for
  /// names with no local handler.
  FunctionIdT getInvalidFunctionId() const {
    return FnIdAllocator.getInvalidId();
  }

  /// Add the given handler to the handler map and make it available for
  /// autonegotiation and execution.
  template <typename Func, typename HandlerT>
  void addHandlerImpl(HandlerT Handler) {

    static_assert(detail::RPCArgTypeCheck<
                      CanDeserializeCheck, typename Func::Type,
                      typename detail::HandlerTraits<HandlerT>::Type>::value,
                  "");

    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
    LocalFunctionIds[Func::getPrototype()] = NewFnId;
    Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
  }

  /// Like addHandlerImpl, but for handlers that deliver their result through
  /// a responder callback rather than a return value (see wrapAsyncHandler).
  template <typename Func, typename HandlerT>
  void addAsyncHandlerImpl(HandlerT Handler) {

    static_assert(detail::RPCArgTypeCheck<
                      CanDeserializeCheck, typename Func::Type,
                      typename detail::AsyncHandlerTraits<
                        typename detail::HandlerTraits<HandlerT>::Type
                      >::Type>::value,
                  "");

    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
    LocalFunctionIds[Func::getPrototype()] = NewFnId;
    Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
  }

  /// Dispatch an incoming response to the handler pending for SeqNo.
  ///
  /// Under ResponsesMutex, the handler is removed from the map and SeqNo is
  /// released; the handler itself runs after the lock is dropped. All pending
  /// responses are abandoned if SeqNo is unknown or the handler fails.
  Error handleResponse(SequenceNumberT SeqNo) {
    using Handler = typename decltype(PendingResponses)::mapped_type;
    Handler PRHandler;

    {
      // Lock the pending responses map and sequence number manager.
      std::unique_lock<std::mutex> Lock(ResponsesMutex);
      auto I = PendingResponses.find(SeqNo);

      if (I != PendingResponses.end()) {
        PRHandler = std::move(I->second);
        PendingResponses.erase(I);
        SequenceNumberMgr.releaseSequenceNumber(SeqNo);
      } else {
        // Unlock the pending results map to prevent recursive lock.
        Lock.unlock();
        abandonPendingResponses();
        return make_error<
                 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
      }
    }

    assert(PRHandler &&
           "If we didn't find a response handler we should have bailed out");

    if (auto Err = PRHandler->handleResponse(C)) {
      abandonPendingResponses();
      return Err;
    }

    return Error::success();
  }

  /// Local side of negotiation: map a function prototype name to its local
  /// id, or to the invalid id if no handler was ever added under that name.
  FunctionIdT handleNegotiate(const std::string &Name) {
    auto I = LocalFunctionIds.find(Name);
    if (I == LocalFunctionIds.end())
      return getInvalidFunctionId();
    return I->second;
  }

  // Find the remote FunctionId for the given function.
  //
  // NegotiateIfNotInMap: negotiate when no id has been recorded for Func.
  // NegotiateIfInvalid: negotiate again when the recorded id is the invalid
  // id. Any negotiated id, including an invalid one, is recorded; an invalid
  // id is reported as CouldNotNegotiate.
  template <typename Func>
  Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
                                            bool NegotiateIfInvalid) {
    bool DoNegotiate;

    // Check if we already have a function id...
    auto I = RemoteFunctionIds.find(Func::getPrototype());
    if (I != RemoteFunctionIds.end()) {
      // If it's valid there's nothing left to do.
      if (I->second != getInvalidFunctionId())
        return I->second;
      DoNegotiate = NegotiateIfInvalid;
    } else
      DoNegotiate = NegotiateIfNotInMap;

    // We don't have a function id for Func yet, but we're allowed to try to
    // negotiate one.
    if (DoNegotiate) {
      auto &Impl = static_cast<ImplT &>(*this);
      if (auto RemoteIdOrErr =
          Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
        RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
        if (*RemoteIdOrErr == getInvalidFunctionId())
          return make_error<CouldNotNegotiate>(Func::getPrototype());
        return *RemoteIdOrErr;
      } else
        return RemoteIdOrErr.takeError();
    }

    // No usable id, and negotiation is not permitted in this situation. Either
    // no id was recorded and NegotiateIfNotInMap was false, or the recorded id
    // is invalid and NegotiateIfInvalid was false.
    return make_error<CouldNotNegotiate>(Func::getPrototype());
  }

  using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;

  // Wrap the given user handler in the necessary argument-deserialization code,
  // result-serialization code, and call to the launch policy (if present).
  // NOTE(review): no launch policy is visible in this code; that part of the
  // comment may be stale.
  //
  // The wrapper deserializes the arguments and ends the receive. It then runs
  // the handler and sends its return value back as the response for SeqNo.
  template <typename Func, typename HandlerT>
  WrappedHandlerFn wrapHandler(HandlerT Handler) {
    return [this, Handler](ChannelT &Channel,
                           SequenceNumberT SeqNo) mutable -> Error {
      // Start by deserializing the arguments.
      using ArgsTuple =
          typename detail::FunctionArgsTuple<
            typename detail::HandlerTraits<HandlerT>::Type>::Type;
      auto Args = std::make_shared<ArgsTuple>();

      if (auto Err =
              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
                  Channel, *Args))
        return Err;

      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
      // for Args. Void cast Args to work around this for now.
      // FIXME: Remove this workaround once we can assume a working GCC version.
      (void)Args;

      // End receive message, unlocking the channel for reading.
      if (auto Err = Channel.endReceiveMessage())
        return Err;

      using HTraits = detail::HandlerTraits<HandlerT>;
      using FuncReturn = typename Func::ReturnType;
      return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
                                         HTraits::unpackAndRun(Handler, *Args));
    };
  }

  // Wrap the given user handler in the necessary argument-deserialization code,
  // result-serialization code, and call to the launch policy (if present).
  // NOTE(review): no launch policy is visible in this code; that part of the
  // comment may be stale.
  //
  // Unlike wrapHandler, the user handler receives a Responder callback. The
  // handler invokes it to send the result as the response for SeqNo.
  template <typename Func, typename HandlerT>
  WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
    return [this, Handler](ChannelT &Channel,
                           SequenceNumberT SeqNo) mutable -> Error {
      // Start by deserializing the arguments.
      using AHTraits = detail::AsyncHandlerTraits<
                         typename detail::HandlerTraits<HandlerT>::Type>;
      using ArgsTuple =
          typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
      auto Args = std::make_shared<ArgsTuple>();

      if (auto Err =
              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
                  Channel, *Args))
        return Err;

      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
      // for Args. Void cast Args to work around this for now.
      // FIXME: Remove this workaround once we can assume a working GCC version.
      (void)Args;

      // End receive message, unlocking the channel for reading.
      if (auto Err = Channel.endReceiveMessage())
        return Err;

      using HTraits = detail::HandlerTraits<HandlerT>;
      using FuncReturn = typename Func::ReturnType;
      auto Responder =
        [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
          return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
                                             std::move(RetVal));
        };

      return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
    };
  }

  // Channel shared by all calls and handlers of this endpoint.
  ChannelT &C;

  // Whether appendCallAsync may negotiate an id for a function with no
  // recorded remote id.
  bool LazyAutoNegotiation;

  RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;

  FunctionIdT ResponseId;
  // Ids this endpoint serves, keyed by prototype string.
  std::map<std::string, FunctionIdT> LocalFunctionIds;
  // Ids learned from the remote. Keyed by the const char * returned from
  // Func::getPrototype(), so lookups compare pointers, not string contents.
  std::map<const char *, FunctionIdT> RemoteFunctionIds;

  std::map<FunctionIdT, WrappedHandlerFn> Handlers;

  // Guards SequenceNumberMgr and PendingResponses.
  std::mutex ResponsesMutex;
  detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
  std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
      PendingResponses;
};
1352 
1353 } // end namespace detail
1354 
/// RPC endpoint for use when another thread runs handlerLoop (or handleOne)
/// to process incoming messages. The blocking callB waits on a future, and
/// that future is fulfilled by whichever thread handles the response.
template <typename ChannelT, typename FunctionIdT = uint32_t,
          typename SequenceNumberT = uint32_t>
class MultiThreadedRPCEndpoint
    : public detail::RPCEndpointBase<
          MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
          ChannelT, FunctionIdT, SequenceNumberT> {
private:
  using BaseClass =
      detail::RPCEndpointBase<
        MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
        ChannelT, FunctionIdT, SequenceNumberT>;

public:
  MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
      : BaseClass(C, LazyAutoNegotiation) {}

  /// Add a handler for the given RPC function.
  /// This installs the given handler functor for the given RPC Function, and
  /// makes the RPC function available for negotiation/calling from the remote.
  template <typename Func, typename HandlerT>
  void addHandler(HandlerT Handler) {
    return this->template addHandlerImpl<Func>(std::move(Handler));
  }

  /// Add a class-method as a handler.
  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
    addHandler<Func>(
      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
  }

  /// Add an asynchronous handler for the given RPC function. The handler
  /// delivers its result through a responder callback.
  template <typename Func, typename HandlerT>
  void addAsyncHandler(HandlerT Handler) {
    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
  }

  /// Add a class-method as an asynchronous handler.
  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
    addAsyncHandler<Func>(
      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
  }

  /// Return type for non-blocking call primitives.
  template <typename Func>
  using NonBlockingCallResult = typename detail::ResultTraits<
      typename Func::ReturnType>::ReturnFutureType;

  /// Call Func on Channel C. Does not block, does not call send. Returns a
  /// future that will receive the result, or an abandoned-response error if
  /// the call is abandoned.
  ///
  /// If the call cannot be appended, the append error is returned.
  /// appendCallAsync has already invoked the handler on every failure path,
  /// which fulfils the promise, so FutureResult.get() below does not block;
  /// the abandoned value it yields is consumed.
  template <typename Func, typename... ArgTs>
  Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
    using RTraits = detail::ResultTraits<typename Func::ReturnType>;
    using ErrorReturn = typename RTraits::ErrorReturnType;
    using ErrorReturnPromise = typename RTraits::ReturnPromiseType;

    ErrorReturnPromise Promise;
    auto FutureResult = Promise.get_future();

    if (auto Err = this->template appendCallAsync<Func>(
            [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
              Promise.set_value(std::move(RetOrErr));
              return Error::success();
            },
            Args...)) {
      RTraits::consumeAbandoned(FutureResult.get());
      return std::move(Err);
    }
    return std::move(FutureResult);
  }

  /// The same as appendCallNB, except that it calls C.send() to flush the
  /// channel after serializing the call. If send fails, all pending responses
  /// are abandoned, this call's abandoned result is consumed, and the send
  /// error is returned.
  template <typename Func, typename... ArgTs>
  Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
    auto Result = appendCallNB<Func>(Args...);
    if (!Result)
      return Result;
    if (auto Err = this->C.send()) {
      this->abandonPendingResponses();
      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
          std::move(Result->get()));
      return std::move(Err);
    }
    return Result;
  }

  /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
  /// for void functions or an Expected<T> for functions returning a T.
  ///
  /// This function is for use in threaded code where another thread is
  /// handling responses and incoming calls.
  template <typename Func, typename... ArgTs,
            typename AltRetT = typename Func::ReturnType>
  typename detail::ResultTraits<AltRetT>::ErrorReturnType
  callB(const ArgTs &... Args) {
    if (auto FutureResOrErr = callNB<Func>(Args...))
      return FutureResOrErr->get();
    else
      return FutureResOrErr.takeError();
  }

  /// Handle incoming RPC calls by running handleOne repeatedly. Only returns
  /// when handleOne fails, passing that error back to the caller.
  Error handlerLoop() {
    while (true)
      if (auto Err = this->handleOne())
        return Err;
    return Error::success();
  }
};
1470 
1471 template <typename ChannelT, typename FunctionIdT = uint32_t,
1472           typename SequenceNumberT = uint32_t>
1473 class SingleThreadedRPCEndpoint
1474     : public detail::RPCEndpointBase<
1475           SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1476           ChannelT, FunctionIdT, SequenceNumberT> {
1477 private:
1478   using BaseClass =
1479       detail::RPCEndpointBase<
1480         SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1481         ChannelT, FunctionIdT, SequenceNumberT>;
1482 
1483 public:
SingleThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1484   SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1485       : BaseClass(C, LazyAutoNegotiation) {}
1486 
1487   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1488   void addHandler(HandlerT Handler) {
1489     return this->template addHandlerImpl<Func>(std::move(Handler));
1490   }
1491 
1492   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1493   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1494     addHandler<Func>(
1495         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1496   }
1497 
1498   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1499   void addAsyncHandler(HandlerT Handler) {
1500     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1501   }
1502 
1503   /// Add a class-method as a handler.
1504   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1505   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1506     addAsyncHandler<Func>(
1507       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1508   }
1509 
1510   template <typename Func, typename... ArgTs,
1511             typename AltRetT = typename Func::ReturnType>
1512   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1513   callB(const ArgTs &... Args) {
1514     bool ReceivedResponse = false;
1515     using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1516     auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1517 
1518     // We have to 'Check' result (which we know is in a success state at this
1519     // point) so that it can be overwritten in the async handler.
1520     (void)!!Result;
1521 
1522     if (auto Err = this->template appendCallAsync<Func>(
1523             [&](ResultType R) {
1524               Result = std::move(R);
1525               ReceivedResponse = true;
1526               return Error::success();
1527             },
1528             Args...)) {
1529       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1530           std::move(Result));
1531       return std::move(Err);
1532     }
1533 
1534     if (auto Err = this->C.send()) {
1535       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1536           std::move(Result));
1537       return std::move(Err);
1538     }
1539 
1540     while (!ReceivedResponse) {
1541       if (auto Err = this->handleOne()) {
1542         detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1543             std::move(Result));
1544         return std::move(Err);
1545       }
1546     }
1547 
1548     return Result;
1549   }
1550 };
1551 
1552 /// Asynchronous dispatch for a function on an RPC endpoint.
1553 template <typename RPCClass, typename Func>
1554 class RPCAsyncDispatch {
1555 public:
RPCAsyncDispatch(RPCClass & Endpoint)1556   RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1557 
1558   template <typename HandlerT, typename... ArgTs>
operator()1559   Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1560     return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1561   }
1562 
1563 private:
1564   RPCClass &Endpoint;
1565 };
1566 
1567 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1568 template <typename Func, typename RPCEndpointT>
rpcAsyncDispatch(RPCEndpointT & Endpoint)1569 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1570   return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1571 }
1572 
1573 /// Allows a set of asynchrounous calls to be dispatched, and then
1574 ///        waited on as a group.
1575 class ParallelCallGroup {
1576 public:
1577 
1578   ParallelCallGroup() = default;
1579   ParallelCallGroup(const ParallelCallGroup &) = delete;
1580   ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1581 
1582   /// Make as asynchronous call.
1583   template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
call(const AsyncDispatcher & AsyncDispatch,HandlerT Handler,const ArgTs &...Args)1584   Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1585              const ArgTs &... Args) {
1586     // Increment the count of outstanding calls. This has to happen before
1587     // we invoke the call, as the handler may (depending on scheduling)
1588     // be run immediately on another thread, and we don't want the decrement
1589     // in the wrapped handler below to run before the increment.
1590     {
1591       std::unique_lock<std::mutex> Lock(M);
1592       ++NumOutstandingCalls;
1593     }
1594 
1595     // Wrap the user handler in a lambda that will decrement the
1596     // outstanding calls count, then poke the condition variable.
1597     using ArgType = typename detail::ResponseHandlerArg<
1598         typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1599     auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
1600       auto Err = Handler(std::move(Arg));
1601       std::unique_lock<std::mutex> Lock(M);
1602       --NumOutstandingCalls;
1603       CV.notify_all();
1604       return Err;
1605     };
1606 
1607     return AsyncDispatch(std::move(WrappedHandler), Args...);
1608   }
1609 
1610   /// Blocks until all calls have been completed and their return value
1611   ///        handlers run.
wait()1612   void wait() {
1613     std::unique_lock<std::mutex> Lock(M);
1614     while (NumOutstandingCalls > 0)
1615       CV.wait(Lock);
1616   }
1617 
1618 private:
1619   std::mutex M;
1620   std::condition_variable CV;
1621   uint32_t NumOutstandingCalls = 0;
1622 };
1623 
1624 /// Convenience class for grouping RPC Functions into APIs that can be
1625 ///        negotiated as a block.
1626 ///
1627 template <typename... Funcs>
1628 class APICalls {
1629 public:
1630 
1631   /// Test whether this API contains Function F.
1632   template <typename F>
1633   class Contains {
1634   public:
1635     static const bool value = false;
1636   };
1637 
1638   /// Negotiate all functions in this API.
1639   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1640   static Error negotiate(RPCEndpoint &R) {
1641     return Error::success();
1642   }
1643 };
1644 
1645 template <typename Func, typename... Funcs>
1646 class APICalls<Func, Funcs...> {
1647 public:
1648 
1649   template <typename F>
1650   class Contains {
1651   public:
1652     static const bool value = std::is_same<F, Func>::value |
1653                               APICalls<Funcs...>::template Contains<F>::value;
1654   };
1655 
1656   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1657   static Error negotiate(RPCEndpoint &R) {
1658     if (auto Err = R.template negotiateFunction<Func>())
1659       return Err;
1660     return APICalls<Funcs...>::negotiate(R);
1661   }
1662 
1663 };
1664 
1665 template <typename... InnerFuncs, typename... Funcs>
1666 class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1667 public:
1668 
1669   template <typename F>
1670   class Contains {
1671   public:
1672     static const bool value =
1673       APICalls<InnerFuncs...>::template Contains<F>::value |
1674       APICalls<Funcs...>::template Contains<F>::value;
1675   };
1676 
1677   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1678   static Error negotiate(RPCEndpoint &R) {
1679     if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1680       return Err;
1681     return APICalls<Funcs...>::negotiate(R);
1682   }
1683 
1684 };
1685 
1686 } // end namespace rpc
1687 } // end namespace orc
1688 } // end namespace llvm
1689 
1690 #endif
1691