1 //===------- RPCUTils.h - Utilities for building RPC APIs -------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Utilities to support construction of simple RPC APIs.
11 //
12 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
13 // programmers, high performance, low memory overhead, and efficient use of the
14 // communications channel.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19 #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
20 
21 #include <map>
22 #include <thread>
23 #include <vector>
24 
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ExecutionEngine/Orc/OrcError.h"
27 #include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
28 #include "llvm/Support/MSVCErrorWorkarounds.h"
29 
30 #include <future>
31 
32 namespace llvm {
33 namespace orc {
34 namespace rpc {
35 
/// Base class of all fatal RPC errors (those that necessarily result in the
/// termination of the RPC session).
///
/// Fatal error types derive from this via
/// ErrorInfo<DerivedError, RPCFatalError> (e.g. BadFunctionCall and
/// InvalidSequenceNumberForResponse in this header). RespondHelper<true>
/// tests for it with errorIsA<RPCFatalError>() / isA<RPCFatalError>() and
/// returns such errors to its caller instead of sending them to the remote.
/// This class itself declares neither log() nor convertToErrorCode().
class RPCFatalError : public ErrorInfo<RPCFatalError> {
public:
  static char ID;
};
42 
/// ConnectionClosed is returned from RPC operations if the RPC connection
/// has already been closed due to either an error or graceful disconnection.
///
/// convertToErrorCode() and log() are only declared here; their definitions
/// live outside this header.
class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
public:
  static char ID;
  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
};
51 
52 /// BadFunctionCall is returned from handleOne when the remote makes a call with
53 /// an unrecognized function id.
54 ///
55 /// This error is fatal because Orc RPC needs to know how to parse a function
56 /// call to know where the next call starts, and if it doesn't recognize the
57 /// function id it cannot parse the call.
58 template <typename FnIdT, typename SeqNoT>
59 class BadFunctionCall
60   : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
61 public:
62   static char ID;
63 
BadFunctionCall(FnIdT FnId,SeqNoT SeqNo)64   BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
65       : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
66 
convertToErrorCode()67   std::error_code convertToErrorCode() const override {
68     return orcError(OrcErrorCode::UnexpectedRPCCall);
69   }
70 
log(raw_ostream & OS)71   void log(raw_ostream &OS) const override {
72     OS << "Call to invalid RPC function id '" << FnId << "' with "
73           "sequence number " << SeqNo;
74   }
75 
76 private:
77   FnIdT FnId;
78   SeqNoT SeqNo;
79 };
80 
81 template <typename FnIdT, typename SeqNoT>
82 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
83 
84 /// InvalidSequenceNumberForResponse is returned from handleOne when a response
85 /// call arrives with a sequence number that doesn't correspond to any in-flight
86 /// function call.
87 ///
88 /// This error is fatal because Orc RPC needs to know how to parse the rest of
89 /// the response call to know where the next call starts, and if it doesn't have
90 /// a result parser for this sequence number it can't do that.
91 template <typename SeqNoT>
92 class InvalidSequenceNumberForResponse
93     : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
94 public:
95   static char ID;
96 
InvalidSequenceNumberForResponse(SeqNoT SeqNo)97   InvalidSequenceNumberForResponse(SeqNoT SeqNo)
98       : SeqNo(std::move(SeqNo)) {}
99 
convertToErrorCode()100   std::error_code convertToErrorCode() const override {
101     return orcError(OrcErrorCode::UnexpectedRPCCall);
102   };
103 
log(raw_ostream & OS)104   void log(raw_ostream &OS) const override {
105     OS << "Response has unknown sequence number " << SeqNo;
106   }
107 private:
108   SeqNoT SeqNo;
109 };
110 
111 template <typename SeqNoT>
112 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
113 
/// This non-fatal error will be passed to asynchronous result handlers in place
/// of a result if the connection goes down before a result returns, or if the
/// function to be called cannot be negotiated with the remote.
///
/// ResponseHandler::createAbandonedResponseError() returns one of these.
/// Unlike BadFunctionCall, it derives from ErrorInfo<ResponseAbandoned>
/// directly rather than from RPCFatalError. Its members are defined outside
/// this header.
class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
public:
  static char ID;

  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
};
124 
/// This error is returned if the remote does not have a handler installed for
/// the given RPC function.
///
/// Holds a function signature string, exposed through getSignature(). The
/// constructor is defined out of line and presumably stores its argument in
/// Signature — confirm against its definition.
class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
public:
  static char ID;

  CouldNotNegotiate(std::string Signature);
  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
  const std::string &getSignature() const { return Signature; }
private:
  std::string Signature;
};
138 
139 template <typename DerivedFunc, typename FnT> class Function;
140 
141 // RPC Function class.
142 // DerivedFunc should be a user defined class with a static 'getName()' method
143 // returning a const char* representing the function's name.
144 template <typename DerivedFunc, typename RetT, typename... ArgTs>
145 class Function<DerivedFunc, RetT(ArgTs...)> {
146 public:
147   /// User defined function type.
148   using Type = RetT(ArgTs...);
149 
150   /// Return type.
151   using ReturnType = RetT;
152 
153   /// Returns the full function prototype as a string.
getPrototype()154   static const char *getPrototype() {
155     std::lock_guard<std::mutex> Lock(NameMutex);
156     if (Name.empty())
157       raw_string_ostream(Name)
158           << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
159           << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
160     return Name.data();
161   }
162 
163 private:
164   static std::mutex NameMutex;
165   static std::string Name;
166 };
167 
168 template <typename DerivedFunc, typename RetT, typename... ArgTs>
169 std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
170 
171 template <typename DerivedFunc, typename RetT, typename... ArgTs>
172 std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
173 
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   Should return a reserved id that will be used to represent missing
/// functions during autonegotiation.
///
/// static T getResponseId():
///   Should return a reserved id that will be used to send function responses
/// (return values).
///
/// static T getNegotiateId():
///   Should return a reserved id for the negotiate function, which will be used
/// to negotiate ids for user defined functions.
///
/// template <typename Func> T allocate():
///   Allocate a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator for integral id types.
///
/// Ids 0, 1 and 2 are reserved for the invalid, response and negotiate ids.
/// allocate() hands out 3, 4, 5, ... in call order, independent of Func. No
/// wrap-around check is performed.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  template <typename Func> T allocate() {
    const T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id after the three reserved ones.
  T NextId = static_cast<T>(3);
};
208 
209 namespace detail {
210 
/// Provides a typedef for a tuple containing the decayed argument types.
///
/// For a function type RetT(ArgTs...), Type is a std::tuple of each ArgT with
/// references and top-level cv-qualifiers stripped. std::decay already removes
/// references, so no separate std::remove_reference step is needed.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
220 
// ResultTraits provides typedefs and utilities specific to the return type
// of functions.
//
// NOTE(review): under _MSC_VER the promise/future value type is
// MSVCPExpected<RetT> (from llvm/Support/MSVCErrorWorkarounds.h) instead of
// Expected<RetT>. This is presumably a workaround for MSVC's std::promise /
// std::future requirements — confirm against that header.
template <typename RetT> class ResultTraits {
public:
  // The return type wrapped in llvm::Expected.
  using ErrorReturnType = Expected<RetT>;

#ifdef _MSC_VER
  // The ErrorReturnType wrapped in a std::promise.
  using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;

  // The ErrorReturnType wrapped in a std::future.
  using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
#else
  // The ErrorReturnType wrapped in a std::promise.
  using ReturnPromiseType = std::promise<ErrorReturnType>;

  // The ErrorReturnType wrapped in a std::future.
  using ReturnFutureType = std::future<ErrorReturnType>;
#endif

  // Create a 'blank' value of the ErrorReturnType, ready and safe to
  // overwrite. It holds a value-initialized RetT, so RetT must be
  // default-constructible wherever this is instantiated.
  static ErrorReturnType createBlankErrorReturnValue() {
    return ErrorReturnType(RetT());
  }

  // Consume an abandoned ErrorReturnType: any error it carries is taken out
  // and discarded with consumeError; a value is simply dropped.
  static void consumeAbandoned(ErrorReturnType RetOrErr) {
    consumeError(RetOrErr.takeError());
  }
};
253 
// ResultTraits specialization for void functions.
//
// There is no value to carry, so the error return type is plain llvm::Error.
// As in the primary template, MSVC builds wrap the promise/future value in
// the MSVCPError workaround type instead.
template <> class ResultTraits<void> {
public:
  // For void functions, ErrorReturnType is llvm::Error.
  using ErrorReturnType = Error;

#ifdef _MSC_VER
  // The ErrorReturnType wrapped in a std::promise.
  using ReturnPromiseType = std::promise<MSVCPError>;

  // The ErrorReturnType wrapped in a std::future.
  using ReturnFutureType = std::future<MSVCPError>;
#else
  // The ErrorReturnType wrapped in a std::promise.
  using ReturnPromiseType = std::promise<ErrorReturnType>;

  // The ErrorReturnType wrapped in a std::future.
  using ReturnFutureType = std::future<ErrorReturnType>;
#endif

  // Create a 'blank' value of the ErrorReturnType, ready and safe to
  // overwrite: Error::success().
  static ErrorReturnType createBlankErrorReturnValue() {
    return ErrorReturnType::success();
  }

  // Consume an abandoned ErrorReturnType by discarding it with consumeError.
  static void consumeAbandoned(ErrorReturnType Err) {
    consumeError(std::move(Err));
  }
};
285 
// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
// handlers for void RPC functions to return either void (in which case they
// implicitly succeed) or Error (in which case their error return is
// propagated). The void-or-Error handler return handling itself lives in
// HandlerTraits::run and WrappedHandlerReturn.
template <> class ResultTraits<Error> : public ResultTraits<void> {};

// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
// handlers for RPC functions returning a T to return either a T (in which
// case they implicitly succeed) or Expected<T> (in which case their error
// return is propagated). Both specializations are empty: every member comes
// from the base class.
template <typename RetT>
class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
298 
299 // Determines whether an RPC function's defined error return type supports
300 // error return value.
301 template <typename T>
302 class SupportsErrorReturn {
303 public:
304   static const bool value = false;
305 };
306 
307 template <>
308 class SupportsErrorReturn<Error> {
309 public:
310   static const bool value = true;
311 };
312 
313 template <typename T>
314 class SupportsErrorReturn<Expected<T>> {
315 public:
316   static const bool value = true;
317 };
318 
319 // RespondHelper packages return values based on whether or not the declared
320 // RPC function return type supports error returns.
321 template <bool FuncSupportsErrorReturn>
322 class RespondHelper;
323 
324 // RespondHelper specialization for functions that support error returns.
325 template <>
326 class RespondHelper<true> {
327 public:
328 
329   // Send Expected<T>.
330   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
331             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)332   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
333                           SequenceNumberT SeqNo,
334                           Expected<HandlerRetT> ResultOrErr) {
335     if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
336       return ResultOrErr.takeError();
337 
338     // Open the response message.
339     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
340       return Err;
341 
342     // Serialize the result.
343     if (auto Err =
344         SerializationTraits<ChannelT, WireRetT,
345                             Expected<HandlerRetT>>::serialize(
346                                                      C, std::move(ResultOrErr)))
347       return Err;
348 
349     // Close the response message.
350     return C.endSendMessage();
351   }
352 
353   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)354   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
355                           SequenceNumberT SeqNo, Error Err) {
356     if (Err && Err.isA<RPCFatalError>())
357       return Err;
358     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
359       return Err2;
360     if (auto Err2 = serializeSeq(C, std::move(Err)))
361       return Err2;
362     return C.endSendMessage();
363   }
364 
365 };
366 
// RespondHelper specialization for functions that do not support error returns.
//
// The wire return type has no way to carry an error, so any error from the
// handler (fatal or not) is returned to the caller and no response is sent.
template <>
class RespondHelper<false> {
public:

  // Send a plain value. On success, writes startSendMessage(ResponseId,
  // SeqNo), the value serialized via SerializationTraits<ChannelT, WireRetT,
  // HandlerRetT>, then endSendMessage().
  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
            typename FunctionIdT, typename SequenceNumberT>
  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
                          SequenceNumberT SeqNo,
                          Expected<HandlerRetT> ResultOrErr) {
    // takeError() also marks ResultOrErr as checked, so *ResultOrErr below is
    // safe.
    if (auto Err = ResultOrErr.takeError())
      return Err;

    // Open the response message.
    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
      return Err;

    // Serialize the result.
    if (auto Err =
        SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
                                                               C, *ResultOrErr))
      return Err;

    // Close the response message.
    return C.endSendMessage();
  }

  // Send a completion notice. On success, an empty message (start then end,
  // no payload) tells the remote that the handler ran.
  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
                          SequenceNumberT SeqNo, Error Err) {
    if (Err)
      return Err;
    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
      return Err2;
    return C.endSendMessage();
  }

};
405 
406 
407 // Send a response of the given wire return type (WireRetT) over the
408 // channel, with the given sequence number.
409 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
410           typename FunctionIdT, typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)411 Error respond(ChannelT &C, const FunctionIdT &ResponseId,
412               SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
413   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
414     template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
415 }
416 
417 // Send an empty response message on the given channel to indicate that
418 // the handler ran.
419 template <typename WireRetT, typename ChannelT, typename FunctionIdT,
420           typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)421 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
422               Error Err) {
423   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
424     sendResult(C, ResponseId, SeqNo, std::move(Err));
425 }
426 
427 // Converts a given type to the equivalent error return type.
428 template <typename T> class WrappedHandlerReturn {
429 public:
430   using Type = Expected<T>;
431 };
432 
433 template <typename T> class WrappedHandlerReturn<Expected<T>> {
434 public:
435   using Type = Expected<T>;
436 };
437 
438 template <> class WrappedHandlerReturn<void> {
439 public:
440   using Type = Error;
441 };
442 
443 template <> class WrappedHandlerReturn<Error> {
444 public:
445   using Type = Error;
446 };
447 
448 template <> class WrappedHandlerReturn<ErrorSuccess> {
449 public:
450   using Type = Error;
451 };
452 
// Traits class that strips the response function from the list of handler
// arguments.
//
// For an async handler whose first parameter is the responder callback:
//   Type       - the handler signature minus the responder, always returning
//                Error.
//   ResultType - what the responder is called with: Expected<ResultT> for a
//                std::function<Error(Expected<ResultT>)> responder, Error for
//                a std::function<Error(Error)> responder.
// Handlers returning ErrorSuccess or void are only recognized with a
// std::function<Error(Error)> responder.
template <typename FnT> class AsyncHandlerTraits;

template <typename ResultT, typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Expected<ResultT>;
};

template <typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

template <typename... ArgTs>
class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

template <typename... ArgTs>
class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Fallback for other responder parameter types (e.g. a const reference to a
// std::function): retry with the decayed responder type.
// NOTE(review): if std::decay leaves ResponseHandlerT unchanged and it is not
// one of the std::function forms above (e.g. a lambda type taken by value),
// this specialization matches again and would inherit from itself, which
// fails to compile.
template <typename ResponseHandlerT, typename... ArgTs>
class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
                                    ArgTs...)> {};
489 
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
// specialization for the given non-function type's call operator.
//
// Taking &T::operator() requires that HandlerT (after removing references)
// has exactly one non-template call operator, e.g. a non-generic lambda.
// Overloaded or templated call operators make this ill-formed.
template <typename HandlerT>
class HandlerTraits : public HandlerTraits<decltype(
                          &std::remove_reference<HandlerT>::type::operator())> {
};
498 
499 // Traits for handlers with a given function type.
500 template <typename RetT, typename... ArgTs>
501 class HandlerTraits<RetT(ArgTs...)> {
502 public:
503   // Function type of the handler.
504   using Type = RetT(ArgTs...);
505 
506   // Return type of the handler.
507   using ReturnType = RetT;
508 
509   // Call the given handler with the given arguments.
510   template <typename HandlerT, typename... TArgTs>
511   static typename WrappedHandlerReturn<RetT>::Type
unpackAndRun(HandlerT & Handler,std::tuple<TArgTs...> & Args)512   unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
513     return unpackAndRunHelper(Handler, Args,
514                               llvm::index_sequence_for<TArgTs...>());
515   }
516 
517   // Call the given handler with the given arguments.
518   template <typename HandlerT, typename ResponderT, typename... TArgTs>
unpackAndRunAsync(HandlerT & Handler,ResponderT & Responder,std::tuple<TArgTs...> & Args)519   static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
520                                  std::tuple<TArgTs...> &Args) {
521     return unpackAndRunAsyncHelper(Handler, Responder, Args,
522                                    llvm::index_sequence_for<TArgTs...>());
523   }
524 
525   // Call the given handler with the given arguments.
526   template <typename HandlerT>
527   static typename std::enable_if<
528       std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
529       Error>::type
run(HandlerT & Handler,ArgTs &&...Args)530   run(HandlerT &Handler, ArgTs &&... Args) {
531     Handler(std::move(Args)...);
532     return Error::success();
533   }
534 
535   template <typename HandlerT, typename... TArgTs>
536   static typename std::enable_if<
537       !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
538       typename HandlerTraits<HandlerT>::ReturnType>::type
run(HandlerT & Handler,TArgTs...Args)539   run(HandlerT &Handler, TArgTs... Args) {
540     return Handler(std::move(Args)...);
541   }
542 
543   // Serialize arguments to the channel.
544   template <typename ChannelT, typename... CArgTs>
serializeArgs(ChannelT & C,const CArgTs...CArgs)545   static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
546     return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
547   }
548 
549   // Deserialize arguments from the channel.
550   template <typename ChannelT, typename... CArgTs>
deserializeArgs(ChannelT & C,std::tuple<CArgTs...> & Args)551   static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
552     return deserializeArgsHelper(C, Args,
553                                  llvm::index_sequence_for<CArgTs...>());
554   }
555 
556 private:
557   template <typename ChannelT, typename... CArgTs, size_t... Indexes>
deserializeArgsHelper(ChannelT & C,std::tuple<CArgTs...> & Args,llvm::index_sequence<Indexes...> _)558   static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
559                                      llvm::index_sequence<Indexes...> _) {
560     return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
561         C, std::get<Indexes>(Args)...);
562   }
563 
564   template <typename HandlerT, typename ArgTuple, size_t... Indexes>
565   static typename WrappedHandlerReturn<
566       typename HandlerTraits<HandlerT>::ReturnType>::Type
unpackAndRunHelper(HandlerT & Handler,ArgTuple & Args,llvm::index_sequence<Indexes...>)567   unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
568                      llvm::index_sequence<Indexes...>) {
569     return run(Handler, std::move(std::get<Indexes>(Args))...);
570   }
571 
572 
573   template <typename HandlerT, typename ResponderT, typename ArgTuple,
574             size_t... Indexes>
575   static typename WrappedHandlerReturn<
576       typename HandlerTraits<HandlerT>::ReturnType>::Type
unpackAndRunAsyncHelper(HandlerT & Handler,ResponderT & Responder,ArgTuple & Args,llvm::index_sequence<Indexes...>)577   unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
578                           ArgTuple &Args,
579                           llvm::index_sequence<Indexes...>) {
580     return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
581   }
582 };
583 
// Handler traits for free functions.
// Each specialization here forwards to HandlerTraits<RetT(ArgTs...)>, so the
// pointer-ness, the class type and any const qualifier of a method are
// dropped.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(*)(ArgTs...)>
  : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for class methods (especially call operators for lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...)>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for const class methods (especially call operators for
// lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...) const>
    : public HandlerTraits<RetT(ArgTs...)> {};
599 
600 // Utility to peel the Expected wrapper off a response handler error type.
601 template <typename HandlerT> class ResponseHandlerArg;
602 
603 template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
604 public:
605   using ArgType = Expected<ArgT>;
606   using UnwrappedArgType = ArgT;
607 };
608 
609 template <typename ArgT>
610 class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
611 public:
612   using ArgType = Expected<ArgT>;
613   using UnwrappedArgType = ArgT;
614 };
615 
616 template <> class ResponseHandlerArg<Error(Error)> {
617 public:
618   using ArgType = Error;
619 };
620 
621 template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
622 public:
623   using ArgType = Error;
624 };
625 
626 // ResponseHandler represents a handler for a not-yet-received function call
627 // result.
628 template <typename ChannelT> class ResponseHandler {
629 public:
~ResponseHandler()630   virtual ~ResponseHandler() {}
631 
632   // Reads the function result off the wire and acts on it. The meaning of
633   // "act" will depend on how this method is implemented in any given
634   // ResponseHandler subclass but could, for example, mean running a
635   // user-specified handler or setting a promise value.
636   virtual Error handleResponse(ChannelT &C) = 0;
637 
638   // Abandons this outstanding result.
639   virtual void abandon() = 0;
640 
641   // Create an error instance representing an abandoned response.
createAbandonedResponseError()642   static Error createAbandonedResponseError() {
643     return make_error<ResponseAbandoned>();
644   }
645 };
646 
// ResponseHandler subclass for RPC functions with non-void returns.
//
// This primary template covers plain value return types. void, Expected<T>
// and Error have their own specializations. HandlerTraits<HandlerT>::Type
// must be Error(Expected<ArgT>) or ErrorSuccess(Expected<ArgT>), the forms
// for which ResponseHandlerArg defines UnwrappedArgType.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler.
  //
  // If deserialization or C.endReceiveMessage() fails, that error is returned
  // and the handler is not called. Otherwise the handler's Error is returned.
  Error handleResponse(ChannelT &C) override {
    using UnwrappedArgType = typename ResponseHandlerArg<
        typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
    // Default-initialized here; deserialize() is expected to overwrite it.
    UnwrappedArgType Result;
    if (auto Err =
            SerializationTraits<ChannelT, FuncRetT,
                                UnwrappedArgType>::deserialize(C, Result))
      return Err;
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};
680 
// ResponseHandler subclass for RPC functions with void returns.
//
// No payload is read from the channel: the response only signals completion.
template <typename ChannelT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, void, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result (no actual value, just a notification that the function
  // has completed on the remote end) by calling the user-defined handler with
  // Error::success().
  //
  // An error from C.endReceiveMessage() is returned without calling the
  // handler.
  Error handleResponse(ChannelT &C) override {
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(Error::success());
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};
709 
710 template <typename ChannelT, typename FuncRetT, typename HandlerT>
711 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
712     : public ResponseHandler<ChannelT> {
713 public:
ResponseHandlerImpl(HandlerT Handler)714   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
715 
716   // Handle the result by deserializing it from the channel then passing it
717   // to the user defined handler.
handleResponse(ChannelT & C)718   Error handleResponse(ChannelT &C) override {
719     using HandlerArgType = typename ResponseHandlerArg<
720         typename HandlerTraits<HandlerT>::Type>::ArgType;
721     HandlerArgType Result((typename HandlerArgType::value_type()));
722 
723     if (auto Err =
724             SerializationTraits<ChannelT, Expected<FuncRetT>,
725                                 HandlerArgType>::deserialize(C, Result))
726       return Err;
727     if (auto Err = C.endReceiveMessage())
728       return Err;
729     return Handler(std::move(Result));
730   }
731 
732   // Abandon this response by calling the handler with an 'abandoned response'
733   // error.
abandon()734   void abandon() override {
735     if (auto Err = Handler(this->createAbandonedResponseError())) {
736       // Handlers should not fail when passed an abandoned response error.
737       report_fatal_error(std::move(Err));
738     }
739   }
740 
741 private:
742   HandlerT Handler;
743 };
744 
745 template <typename ChannelT, typename HandlerT>
746 class ResponseHandlerImpl<ChannelT, Error, HandlerT>
747     : public ResponseHandler<ChannelT> {
748 public:
ResponseHandlerImpl(HandlerT Handler)749   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
750 
751   // Handle the result by deserializing it from the channel then passing it
752   // to the user defined handler.
handleResponse(ChannelT & C)753   Error handleResponse(ChannelT &C) override {
754     Error Result = Error::success();
755     if (auto Err =
756             SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
757       return Err;
758     if (auto Err = C.endReceiveMessage())
759       return Err;
760     return Handler(std::move(Result));
761   }
762 
763   // Abandon this response by calling the handler with an 'abandoned response'
764   // error.
abandon()765   void abandon() override {
766     if (auto Err = Handler(this->createAbandonedResponseError())) {
767       // Handlers should not fail when passed an abandoned response error.
768       report_fatal_error(std::move(Err));
769     }
770   }
771 
772 private:
773   HandlerT Handler;
774 };
775 
776 // Create a ResponseHandler from a given user handler.
777 template <typename ChannelT, typename FuncRetT, typename HandlerT>
createResponseHandler(HandlerT H)778 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
779   return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
780       std::move(H));
781 }
782 
// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// Stores a reference to the target object (which must outlive the wrapper)
// and a pointer to one of its non-const methods. Calling the wrapper moves its
// arguments into that method and returns the method's result.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Target(Instance), Fn(Method) {}

  RetT operator()(ArgTs &&... Args) {
    return (Target.*Fn)(std::move(Args)...);
  }

private:
  ClassT &Target;
  MethodT Fn;
};
799 
800 // Helper that provides a Functor for deserializing arguments.
801 template <typename... ArgTs> class ReadArgs {
802 public:
operator()803   Error operator()() { return Error::success(); }
804 };
805 
806 template <typename ArgT, typename... ArgTs>
807 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
808 public:
ReadArgs(ArgT & Arg,ArgTs &...Args)809   ReadArgs(ArgT &Arg, ArgTs &... Args)
810       : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
811 
operator()812   Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
813     this->Arg = std::move(ArgVal);
814     return ReadArgs<ArgTs...>::operator()(ArgVals...);
815   }
816 
817 private:
818   ArgT &Arg;
819 };
820 
// Manage sequence numbers.
//
// Hands out numbers starting from zero, preferring to recycle released ones
// (most recently released first). Every operation takes SeqNoLock, so a single
// manager may be shared between threads.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Return to the initial state: numbering restarts at zero and all released
  // numbers are forgotten.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Obtain a sequence number, reusing the most recently released one if any.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    SequenceNumberT Fresh = NextSequenceNumber;
    ++NextSequenceNumber;
    return Fresh;
  }

  // Hand a number back so a later getSequenceNumber() call can reuse it.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
853 
// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple. Only equal-length tuples have a definition.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Base case: two empty tuples trivially satisfy the predicate.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static constexpr bool value = true;
};

// Recursive case: the heads must satisfy P, and so must the tails.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using HeadCheck = P<T, U>;
  using TailCheck =
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static constexpr bool value = HeadCheck::value && TailCheck::value;
};
873 
874 template <template <class, class> class P, typename T1Sig, typename T2Sig>
875 class RPCArgTypeCheck {
876 public:
877   using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
878   using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
879 
880   static_assert(std::tuple_size<T1Tuple>::value >=
881                     std::tuple_size<T2Tuple>::value,
882                 "Too many arguments to RPC call");
883   static_assert(std::tuple_size<T1Tuple>::value <=
884                     std::tuple_size<T2Tuple>::value,
885                 "Too few arguments to RPC call");
886 
887   static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
888 };
889 
890 template <typename ChannelT, typename WireT, typename ConcreteT>
891 class CanSerialize {
892 private:
893   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
894 
895   template <typename T>
896   static std::true_type
897   check(typename std::enable_if<
898         std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
899                                            std::declval<const ConcreteT &>())),
900                      Error>::value,
901         void *>::type);
902 
903   template <typename> static std::false_type check(...);
904 
905 public:
906   static const bool value = decltype(check<S>(0))::value;
907 };
908 
909 template <typename ChannelT, typename WireT, typename ConcreteT>
910 class CanDeserialize {
911 private:
912   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
913 
914   template <typename T>
915   static std::true_type
916   check(typename std::enable_if<
917         std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
918                                              std::declval<ConcreteT &>())),
919                      Error>::value,
920         void *>::type);
921 
922   template <typename> static std::false_type check(...);
923 
924 public:
925   static const bool value = decltype(check<S>(0))::value;
926 };
927 
928 /// Contains primitive utilities for defining, calling and handling calls to
929 /// remote procedures. ChannelT is a bidirectional stream conforming to the
930 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
931 /// identifier type that must be serializable on ChannelT, and SequenceNumberT
932 /// is an integral type that will be used to number in-flight function calls.
933 ///
934 /// These utilities support the construction of very primitive RPC utilities.
935 /// Their intent is to ensure correct serialization and deserialization of
936 /// procedure arguments, and to keep the client and server's view of the API in
937 /// sync.
938 template <typename ImplT, typename ChannelT, typename FunctionIdT,
939           typename SequenceNumberT>
940 class RPCEndpointBase {
941 protected:
942   class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
943   public:
getName()944     static const char *getName() { return "__orc_rpc$invalid"; }
945   };
946 
947   class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
948   public:
getName()949     static const char *getName() { return "__orc_rpc$response"; }
950   };
951 
952   class OrcRPCNegotiate
953       : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
954   public:
getName()955     static const char *getName() { return "__orc_rpc$negotiate"; }
956   };
957 
958   // Helper predicate for testing for the presence of SerializeTraits
959   // serializers.
960   template <typename WireT, typename ConcreteT>
961   class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
962   public:
963     using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
964 
965     static_assert(value, "Missing serializer for argument (Can't serialize the "
966                          "first template type argument of CanSerializeCheck "
967                          "from the second)");
968   };
969 
970   // Helper predicate for testing for the presence of SerializeTraits
971   // deserializers.
972   template <typename WireT, typename ConcreteT>
973   class CanDeserializeCheck
974       : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
975   public:
976     using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
977 
978     static_assert(value, "Missing deserializer for argument (Can't deserialize "
979                          "the second template type argument of "
980                          "CanDeserializeCheck from the first)");
981   };
982 
983 public:
984   /// Construct an RPC instance on a channel.
RPCEndpointBase(ChannelT & C,bool LazyAutoNegotiation)985   RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
986       : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
987     // Hold ResponseId in a special variable, since we expect Response to be
988     // called relatively frequently, and want to avoid the map lookup.
989     ResponseId = FnIdAllocator.getResponseId();
990     RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
991 
992     // Register the negotiate function id and handler.
993     auto NegotiateId = FnIdAllocator.getNegotiateId();
994     RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
995     Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
996         [this](const std::string &Name) { return handleNegotiate(Name); });
997   }
998 
999 
1000   /// Negotiate a function id for Func with the other end of the channel.
1001   template <typename Func> Error negotiateFunction(bool Retry = false) {
1002     return getRemoteFunctionId<Func>(true, Retry).takeError();
1003   }
1004 
1005   /// Append a call Func, does not call send on the channel.
1006   /// The first argument specifies a user-defined handler to be run when the
1007   /// function returns. The handler should take an Expected<Func::ReturnType>,
1008   /// or an Error (if Func::ReturnType is void). The handler will be called
1009   /// with an error if the return value is abandoned due to a channel error.
1010   template <typename Func, typename HandlerT, typename... ArgTs>
appendCallAsync(HandlerT Handler,const ArgTs &...Args)1011   Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1012 
1013     static_assert(
1014         detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1015                                 void(ArgTs...)>::value,
1016         "");
1017 
1018     // Look up the function ID.
1019     FunctionIdT FnId;
1020     if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1021       FnId = *FnIdOrErr;
1022     else {
1023       // Negotiation failed. Notify the handler then return the negotiate-failed
1024       // error.
1025       cantFail(Handler(make_error<ResponseAbandoned>()));
1026       return FnIdOrErr.takeError();
1027     }
1028 
1029     SequenceNumberT SeqNo; // initialized in locked scope below.
1030     {
1031       // Lock the pending responses map and sequence number manager.
1032       std::lock_guard<std::mutex> Lock(ResponsesMutex);
1033 
1034       // Allocate a sequence number.
1035       SeqNo = SequenceNumberMgr.getSequenceNumber();
1036       assert(!PendingResponses.count(SeqNo) &&
1037              "Sequence number already allocated");
1038 
1039       // Install the user handler.
1040       PendingResponses[SeqNo] =
1041         detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1042             std::move(Handler));
1043     }
1044 
1045     // Open the function call message.
1046     if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1047       abandonPendingResponses();
1048       return Err;
1049     }
1050 
1051     // Serialize the call arguments.
1052     if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1053             C, Args...)) {
1054       abandonPendingResponses();
1055       return Err;
1056     }
1057 
1058     // Close the function call messagee.
1059     if (auto Err = C.endSendMessage()) {
1060       abandonPendingResponses();
1061       return Err;
1062     }
1063 
1064     return Error::success();
1065   }
1066 
sendAppendedCalls()1067   Error sendAppendedCalls() { return C.send(); };
1068 
1069   template <typename Func, typename HandlerT, typename... ArgTs>
callAsync(HandlerT Handler,const ArgTs &...Args)1070   Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1071     if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1072       return Err;
1073     return C.send();
1074   }
1075 
1076   /// Handle one incoming call.
handleOne()1077   Error handleOne() {
1078     FunctionIdT FnId;
1079     SequenceNumberT SeqNo;
1080     if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1081       abandonPendingResponses();
1082       return Err;
1083     }
1084     if (FnId == ResponseId)
1085       return handleResponse(SeqNo);
1086     auto I = Handlers.find(FnId);
1087     if (I != Handlers.end())
1088       return I->second(C, SeqNo);
1089 
1090     // else: No handler found. Report error to client?
1091     return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1092                                                                      SeqNo);
1093   }
1094 
1095   /// Helper for handling setter procedures - this method returns a functor that
1096   /// sets the variables referred to by Args... to values deserialized from the
1097   /// channel.
1098   /// E.g.
1099   ///
1100   ///   typedef Function<0, bool, int> Func1;
1101   ///
1102   ///   ...
1103   ///   bool B;
1104   ///   int I;
1105   ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1106   ///     /* Handle Args */ ;
1107   ///
1108   template <typename... ArgTs>
readArgs(ArgTs &...Args)1109   static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1110     return detail::ReadArgs<ArgTs...>(Args...);
1111   }
1112 
1113   /// Abandon all outstanding result handlers.
1114   ///
1115   /// This will call all currently registered result handlers to receive an
1116   /// "abandoned" error as their argument. This is used internally by the RPC
1117   /// in error situations, but can also be called directly by clients who are
1118   /// disconnecting from the remote and don't or can't expect responses to their
1119   /// outstanding calls. (Especially for outstanding blocking calls, calling
1120   /// this function may be necessary to avoid dead threads).
abandonPendingResponses()1121   void abandonPendingResponses() {
1122     // Lock the pending responses map and sequence number manager.
1123     std::lock_guard<std::mutex> Lock(ResponsesMutex);
1124 
1125     for (auto &KV : PendingResponses)
1126       KV.second->abandon();
1127     PendingResponses.clear();
1128     SequenceNumberMgr.reset();
1129   }
1130 
1131   /// Remove the handler for the given function.
1132   /// A handler must currently be registered for this function.
1133   template <typename Func>
removeHandler()1134   void removeHandler() {
1135     auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1136     assert(IdItr != LocalFunctionIds.end() &&
1137            "Function does not have a registered handler");
1138     auto HandlerItr = Handlers.find(IdItr->second);
1139     assert(HandlerItr != Handlers.end() &&
1140            "Function does not have a registered handler");
1141     Handlers.erase(HandlerItr);
1142   }
1143 
1144   /// Clear all handlers.
clearHandlers()1145   void clearHandlers() {
1146     Handlers.clear();
1147   }
1148 
1149 protected:
1150 
getInvalidFunctionId()1151   FunctionIdT getInvalidFunctionId() const {
1152     return FnIdAllocator.getInvalidId();
1153   }
1154 
1155   /// Add the given handler to the handler map and make it available for
1156   /// autonegotiation and execution.
1157   template <typename Func, typename HandlerT>
addHandlerImpl(HandlerT Handler)1158   void addHandlerImpl(HandlerT Handler) {
1159 
1160     static_assert(detail::RPCArgTypeCheck<
1161                       CanDeserializeCheck, typename Func::Type,
1162                       typename detail::HandlerTraits<HandlerT>::Type>::value,
1163                   "");
1164 
1165     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1166     LocalFunctionIds[Func::getPrototype()] = NewFnId;
1167     Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1168   }
1169 
1170   template <typename Func, typename HandlerT>
addAsyncHandlerImpl(HandlerT Handler)1171   void addAsyncHandlerImpl(HandlerT Handler) {
1172 
1173     static_assert(detail::RPCArgTypeCheck<
1174                       CanDeserializeCheck, typename Func::Type,
1175                       typename detail::AsyncHandlerTraits<
1176                         typename detail::HandlerTraits<HandlerT>::Type
1177                       >::Type>::value,
1178                   "");
1179 
1180     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1181     LocalFunctionIds[Func::getPrototype()] = NewFnId;
1182     Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1183   }
1184 
handleResponse(SequenceNumberT SeqNo)1185   Error handleResponse(SequenceNumberT SeqNo) {
1186     using Handler = typename decltype(PendingResponses)::mapped_type;
1187     Handler PRHandler;
1188 
1189     {
1190       // Lock the pending responses map and sequence number manager.
1191       std::unique_lock<std::mutex> Lock(ResponsesMutex);
1192       auto I = PendingResponses.find(SeqNo);
1193 
1194       if (I != PendingResponses.end()) {
1195         PRHandler = std::move(I->second);
1196         PendingResponses.erase(I);
1197         SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1198       } else {
1199         // Unlock the pending results map to prevent recursive lock.
1200         Lock.unlock();
1201         abandonPendingResponses();
1202         return make_error<
1203                  InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1204       }
1205     }
1206 
1207     assert(PRHandler &&
1208            "If we didn't find a response handler we should have bailed out");
1209 
1210     if (auto Err = PRHandler->handleResponse(C)) {
1211       abandonPendingResponses();
1212       return Err;
1213     }
1214 
1215     return Error::success();
1216   }
1217 
handleNegotiate(const std::string & Name)1218   FunctionIdT handleNegotiate(const std::string &Name) {
1219     auto I = LocalFunctionIds.find(Name);
1220     if (I == LocalFunctionIds.end())
1221       return getInvalidFunctionId();
1222     return I->second;
1223   }
1224 
1225   // Find the remote FunctionId for the given function.
1226   template <typename Func>
getRemoteFunctionId(bool NegotiateIfNotInMap,bool NegotiateIfInvalid)1227   Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1228                                             bool NegotiateIfInvalid) {
1229     bool DoNegotiate;
1230 
1231     // Check if we already have a function id...
1232     auto I = RemoteFunctionIds.find(Func::getPrototype());
1233     if (I != RemoteFunctionIds.end()) {
1234       // If it's valid there's nothing left to do.
1235       if (I->second != getInvalidFunctionId())
1236         return I->second;
1237       DoNegotiate = NegotiateIfInvalid;
1238     } else
1239       DoNegotiate = NegotiateIfNotInMap;
1240 
1241     // We don't have a function id for Func yet, but we're allowed to try to
1242     // negotiate one.
1243     if (DoNegotiate) {
1244       auto &Impl = static_cast<ImplT &>(*this);
1245       if (auto RemoteIdOrErr =
1246           Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1247         RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1248         if (*RemoteIdOrErr == getInvalidFunctionId())
1249           return make_error<CouldNotNegotiate>(Func::getPrototype());
1250         return *RemoteIdOrErr;
1251       } else
1252         return RemoteIdOrErr.takeError();
1253     }
1254 
1255     // No key was available in the map and we weren't allowed to try to
1256     // negotiate one, so return an unknown function error.
1257     return make_error<CouldNotNegotiate>(Func::getPrototype());
1258   }
1259 
1260   using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1261 
1262   // Wrap the given user handler in the necessary argument-deserialization code,
1263   // result-serialization code, and call to the launch policy (if present).
1264   template <typename Func, typename HandlerT>
wrapHandler(HandlerT Handler)1265   WrappedHandlerFn wrapHandler(HandlerT Handler) {
1266     return [this, Handler](ChannelT &Channel,
1267                            SequenceNumberT SeqNo) mutable -> Error {
1268       // Start by deserializing the arguments.
1269       using ArgsTuple =
1270           typename detail::FunctionArgsTuple<
1271             typename detail::HandlerTraits<HandlerT>::Type>::Type;
1272       auto Args = std::make_shared<ArgsTuple>();
1273 
1274       if (auto Err =
1275               detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1276                   Channel, *Args))
1277         return Err;
1278 
1279       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1280       // for RPCArgs. Void cast RPCArgs to work around this for now.
1281       // FIXME: Remove this workaround once we can assume a working GCC version.
1282       (void)Args;
1283 
1284       // End receieve message, unlocking the channel for reading.
1285       if (auto Err = Channel.endReceiveMessage())
1286         return Err;
1287 
1288       using HTraits = detail::HandlerTraits<HandlerT>;
1289       using FuncReturn = typename Func::ReturnType;
1290       return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1291                                          HTraits::unpackAndRun(Handler, *Args));
1292     };
1293   }
1294 
1295   // Wrap the given user handler in the necessary argument-deserialization code,
1296   // result-serialization code, and call to the launch policy (if present).
1297   template <typename Func, typename HandlerT>
wrapAsyncHandler(HandlerT Handler)1298   WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1299     return [this, Handler](ChannelT &Channel,
1300                            SequenceNumberT SeqNo) mutable -> Error {
1301       // Start by deserializing the arguments.
1302       using AHTraits = detail::AsyncHandlerTraits<
1303                          typename detail::HandlerTraits<HandlerT>::Type>;
1304       using ArgsTuple =
1305           typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1306       auto Args = std::make_shared<ArgsTuple>();
1307 
1308       if (auto Err =
1309               detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1310                   Channel, *Args))
1311         return Err;
1312 
1313       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1314       // for RPCArgs. Void cast RPCArgs to work around this for now.
1315       // FIXME: Remove this workaround once we can assume a working GCC version.
1316       (void)Args;
1317 
1318       // End receieve message, unlocking the channel for reading.
1319       if (auto Err = Channel.endReceiveMessage())
1320         return Err;
1321 
1322       using HTraits = detail::HandlerTraits<HandlerT>;
1323       using FuncReturn = typename Func::ReturnType;
1324       auto Responder =
1325         [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1326           return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1327                                              std::move(RetVal));
1328         };
1329 
1330       return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1331     };
1332   }
1333 
1334   ChannelT &C;
1335 
1336   bool LazyAutoNegotiation;
1337 
1338   RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1339 
1340   FunctionIdT ResponseId;
1341   std::map<std::string, FunctionIdT> LocalFunctionIds;
1342   std::map<const char *, FunctionIdT> RemoteFunctionIds;
1343 
1344   std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1345 
1346   std::mutex ResponsesMutex;
1347   detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1348   std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1349       PendingResponses;
1350 };
1351 
1352 } // end namespace detail
1353 
1354 template <typename ChannelT, typename FunctionIdT = uint32_t,
1355           typename SequenceNumberT = uint32_t>
1356 class MultiThreadedRPCEndpoint
1357     : public detail::RPCEndpointBase<
1358           MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1359           ChannelT, FunctionIdT, SequenceNumberT> {
1360 private:
1361   using BaseClass =
1362       detail::RPCEndpointBase<
1363         MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1364         ChannelT, FunctionIdT, SequenceNumberT>;
1365 
1366 public:
MultiThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1367   MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1368       : BaseClass(C, LazyAutoNegotiation) {}
1369 
1370   /// Add a handler for the given RPC function.
1371   /// This installs the given handler functor for the given RPC Function, and
1372   /// makes the RPC function available for negotiation/calling from the remote.
1373   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1374   void addHandler(HandlerT Handler) {
1375     return this->template addHandlerImpl<Func>(std::move(Handler));
1376   }
1377 
1378   /// Add a class-method as a handler.
1379   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1380   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1381     addHandler<Func>(
1382       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1383   }
1384 
1385   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1386   void addAsyncHandler(HandlerT Handler) {
1387     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1388   }
1389 
1390   /// Add a class-method as a handler.
1391   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1392   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1393     addAsyncHandler<Func>(
1394       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1395   }
1396 
1397   /// Return type for non-blocking call primitives.
1398   template <typename Func>
1399   using NonBlockingCallResult = typename detail::ResultTraits<
1400       typename Func::ReturnType>::ReturnFutureType;
1401 
1402   /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1403   /// of a future result and the sequence number assigned to the result.
1404   ///
1405   /// This utility function is primarily used for single-threaded mode support,
1406   /// where the sequence number can be used to wait for the corresponding
1407   /// result. In multi-threaded mode the appendCallNB method, which does not
1408   /// return the sequence numeber, should be preferred.
1409   template <typename Func, typename... ArgTs>
appendCallNB(const ArgTs &...Args)1410   Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1411     using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1412     using ErrorReturn = typename RTraits::ErrorReturnType;
1413     using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1414 
1415     // FIXME: Stack allocate and move this into the handler once LLVM builds
1416     //        with C++14.
1417     auto Promise = std::make_shared<ErrorReturnPromise>();
1418     auto FutureResult = Promise->get_future();
1419 
1420     if (auto Err = this->template appendCallAsync<Func>(
1421             [Promise](ErrorReturn RetOrErr) {
1422               Promise->set_value(std::move(RetOrErr));
1423               return Error::success();
1424             },
1425             Args...)) {
1426       RTraits::consumeAbandoned(FutureResult.get());
1427       return std::move(Err);
1428     }
1429     return std::move(FutureResult);
1430   }
1431 
1432   /// The same as appendCallNBWithSeq, except that it calls C.send() to
1433   /// flush the channel after serializing the call.
1434   template <typename Func, typename... ArgTs>
callNB(const ArgTs &...Args)1435   Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1436     auto Result = appendCallNB<Func>(Args...);
1437     if (!Result)
1438       return Result;
1439     if (auto Err = this->C.send()) {
1440       this->abandonPendingResponses();
1441       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1442           std::move(Result->get()));
1443       return std::move(Err);
1444     }
1445     return Result;
1446   }
1447 
1448   /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1449   /// for void functions or an Expected<T> for functions returning a T.
1450   ///
1451   /// This function is for use in threaded code where another thread is
1452   /// handling responses and incoming calls.
1453   template <typename Func, typename... ArgTs,
1454             typename AltRetT = typename Func::ReturnType>
1455   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1456   callB(const ArgTs &... Args) {
1457     if (auto FutureResOrErr = callNB<Func>(Args...))
1458       return FutureResOrErr->get();
1459     else
1460       return FutureResOrErr.takeError();
1461   }
1462 
1463   /// Handle incoming RPC calls.
handlerLoop()1464   Error handlerLoop() {
1465     while (true)
1466       if (auto Err = this->handleOne())
1467         return Err;
1468     return Error::success();
1469   }
1470 };
1471 
1472 template <typename ChannelT, typename FunctionIdT = uint32_t,
1473           typename SequenceNumberT = uint32_t>
1474 class SingleThreadedRPCEndpoint
1475     : public detail::RPCEndpointBase<
1476           SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1477           ChannelT, FunctionIdT, SequenceNumberT> {
1478 private:
1479   using BaseClass =
1480       detail::RPCEndpointBase<
1481         SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1482         ChannelT, FunctionIdT, SequenceNumberT>;
1483 
1484 public:
SingleThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1485   SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1486       : BaseClass(C, LazyAutoNegotiation) {}
1487 
1488   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1489   void addHandler(HandlerT Handler) {
1490     return this->template addHandlerImpl<Func>(std::move(Handler));
1491   }
1492 
1493   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1494   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1495     addHandler<Func>(
1496         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1497   }
1498 
1499   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1500   void addAsyncHandler(HandlerT Handler) {
1501     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1502   }
1503 
1504   /// Add a class-method as a handler.
1505   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1506   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1507     addAsyncHandler<Func>(
1508       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1509   }
1510 
1511   template <typename Func, typename... ArgTs,
1512             typename AltRetT = typename Func::ReturnType>
1513   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1514   callB(const ArgTs &... Args) {
1515     bool ReceivedResponse = false;
1516     using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1517     auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1518 
1519     // We have to 'Check' result (which we know is in a success state at this
1520     // point) so that it can be overwritten in the async handler.
1521     (void)!!Result;
1522 
1523     if (auto Err = this->template appendCallAsync<Func>(
1524             [&](ResultType R) {
1525               Result = std::move(R);
1526               ReceivedResponse = true;
1527               return Error::success();
1528             },
1529             Args...)) {
1530       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1531           std::move(Result));
1532       return std::move(Err);
1533     }
1534 
1535     while (!ReceivedResponse) {
1536       if (auto Err = this->handleOne()) {
1537         detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1538             std::move(Result));
1539         return std::move(Err);
1540       }
1541     }
1542 
1543     return Result;
1544   }
1545 };
1546 
1547 /// Asynchronous dispatch for a function on an RPC endpoint.
1548 template <typename RPCClass, typename Func>
1549 class RPCAsyncDispatch {
1550 public:
RPCAsyncDispatch(RPCClass & Endpoint)1551   RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1552 
1553   template <typename HandlerT, typename... ArgTs>
operator()1554   Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1555     return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1556   }
1557 
1558 private:
1559   RPCClass &Endpoint;
1560 };
1561 
1562 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1563 template <typename Func, typename RPCEndpointT>
rpcAsyncDispatch(RPCEndpointT & Endpoint)1564 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1565   return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1566 }
1567 
1568 /// Allows a set of asynchrounous calls to be dispatched, and then
1569 ///        waited on as a group.
1570 class ParallelCallGroup {
1571 public:
1572 
1573   ParallelCallGroup() = default;
1574   ParallelCallGroup(const ParallelCallGroup &) = delete;
1575   ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1576 
1577   /// Make as asynchronous call.
1578   template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
call(const AsyncDispatcher & AsyncDispatch,HandlerT Handler,const ArgTs &...Args)1579   Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1580              const ArgTs &... Args) {
1581     // Increment the count of outstanding calls. This has to happen before
1582     // we invoke the call, as the handler may (depending on scheduling)
1583     // be run immediately on another thread, and we don't want the decrement
1584     // in the wrapped handler below to run before the increment.
1585     {
1586       std::unique_lock<std::mutex> Lock(M);
1587       ++NumOutstandingCalls;
1588     }
1589 
1590     // Wrap the user handler in a lambda that will decrement the
1591     // outstanding calls count, then poke the condition variable.
1592     using ArgType = typename detail::ResponseHandlerArg<
1593         typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1594     // FIXME: Move handler into wrapped handler once we have C++14.
1595     auto WrappedHandler = [this, Handler](ArgType Arg) {
1596       auto Err = Handler(std::move(Arg));
1597       std::unique_lock<std::mutex> Lock(M);
1598       --NumOutstandingCalls;
1599       CV.notify_all();
1600       return Err;
1601     };
1602 
1603     return AsyncDispatch(std::move(WrappedHandler), Args...);
1604   }
1605 
1606   /// Blocks until all calls have been completed and their return value
1607   ///        handlers run.
wait()1608   void wait() {
1609     std::unique_lock<std::mutex> Lock(M);
1610     while (NumOutstandingCalls > 0)
1611       CV.wait(Lock);
1612   }
1613 
1614 private:
1615   std::mutex M;
1616   std::condition_variable CV;
1617   uint32_t NumOutstandingCalls = 0;
1618 };
1619 
1620 /// Convenience class for grouping RPC Functions into APIs that can be
1621 ///        negotiated as a block.
1622 ///
1623 template <typename... Funcs>
1624 class APICalls {
1625 public:
1626 
1627   /// Test whether this API contains Function F.
1628   template <typename F>
1629   class Contains {
1630   public:
1631     static const bool value = false;
1632   };
1633 
1634   /// Negotiate all functions in this API.
1635   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1636   static Error negotiate(RPCEndpoint &R) {
1637     return Error::success();
1638   }
1639 };
1640 
1641 template <typename Func, typename... Funcs>
1642 class APICalls<Func, Funcs...> {
1643 public:
1644 
1645   template <typename F>
1646   class Contains {
1647   public:
1648     static const bool value = std::is_same<F, Func>::value |
1649                               APICalls<Funcs...>::template Contains<F>::value;
1650   };
1651 
1652   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1653   static Error negotiate(RPCEndpoint &R) {
1654     if (auto Err = R.template negotiateFunction<Func>())
1655       return Err;
1656     return APICalls<Funcs...>::negotiate(R);
1657   }
1658 
1659 };
1660 
1661 template <typename... InnerFuncs, typename... Funcs>
1662 class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1663 public:
1664 
1665   template <typename F>
1666   class Contains {
1667   public:
1668     static const bool value =
1669       APICalls<InnerFuncs...>::template Contains<F>::value |
1670       APICalls<Funcs...>::template Contains<F>::value;
1671   };
1672 
1673   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1674   static Error negotiate(RPCEndpoint &R) {
1675     if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1676       return Err;
1677     return APICalls<Funcs...>::negotiate(R);
1678   }
1679 
1680 };
1681 
1682 } // end namespace rpc
1683 } // end namespace orc
1684 } // end namespace llvm
1685 
1686 #endif
1687