1 // Copyright (c) 2020 Cloudflare, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 
22 #include "reconnect.h"
23 
24 namespace capnp {
25 
26 namespace {
27 
28 class ReconnectHook final: public ClientHook, public kj::Refcounted {
29 public:
ReconnectHook(kj::Function<Capability::Client ()> connectParam,bool lazy=false)30   ReconnectHook(kj::Function<Capability::Client()> connectParam, bool lazy = false)
31       : connect(kj::mv(connectParam)),
32         current(lazy ? kj::Maybe<kj::Own<ClientHook>>() : ClientHook::from(connect())) {}
33 
newCall(uint64_t interfaceId,uint16_t methodId,kj::Maybe<MessageSize> sizeHint)34   Request<AnyPointer, AnyPointer> newCall(
35       uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
36     auto result = getCurrent().newCall(interfaceId, methodId, sizeHint);
37     AnyPointer::Builder builder = result;
38     auto hook = kj::heap<RequestImpl>(kj::addRef(*this), RequestHook::from(kj::mv(result)));
39     return { builder, kj::mv(hook) };
40   }
41 
call(uint64_t interfaceId,uint16_t methodId,kj::Own<CallContextHook> && context)42   VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
43                               kj::Own<CallContextHook>&& context) override {
44     auto result = getCurrent().call(interfaceId, methodId, kj::mv(context));
45     wrap(result.promise);
46     return result;
47   }
48 
getResolved()49   kj::Maybe<ClientHook&> getResolved() override {
50     // We can't let people resolve to the underlying capability because then we wouldn't be able
51     // to redirect them later.
52     return nullptr;
53   }
54 
whenMoreResolved()55   kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
56     return nullptr;
57   }
58 
addRef()59   kj::Own<ClientHook> addRef() override {
60     return kj::addRef(*this);
61   }
62 
getBrand()63   const void* getBrand() override {
64     return nullptr;
65   }
66 
getFd()67   kj::Maybe<int> getFd() override {
68     // It's not safe to return current->getFd() because normally callers wouldn't expect the FD to
69     // change or go away over time, but this one could whenever we reconnect. If there's a use
70     // case for being able to access the FD here, we'll need a different interface to do it.
71     return nullptr;
72   }
73 
74 private:
75   kj::Function<Capability::Client()> connect;
76   kj::Maybe<kj::Own<ClientHook>> current;
77   uint generation = 0;
78 
79   template <typename T>
wrap(kj::Promise<T> & promise)80   void wrap(kj::Promise<T>& promise) {
81     promise = promise.catch_(
82         [self = kj::addRef(*this), startGeneration = generation]
83         (kj::Exception&& exception) mutable -> kj::Promise<T> {
84       if (exception.getType() == kj::Exception::Type::DISCONNECTED &&
85           self->generation == startGeneration) {
86         self->generation++;
87         KJ_IF_MAYBE(e2, kj::runCatchingExceptions([&]() {
88           self->current = ClientHook::from(self->connect());
89         })) {
90           self->current = newBrokenCap(kj::mv(*e2));
91         }
92       }
93       return kj::mv(exception);
94     });
95   }
96 
getCurrent()97   ClientHook& getCurrent() {
98     KJ_IF_MAYBE(c, current) {
99       return **c;
100     } else {
101       return *current.emplace(ClientHook::from(connect()));
102     }
103   }
104 
105   class RequestImpl final: public RequestHook {
106   public:
RequestImpl(kj::Own<ReconnectHook> parent,kj::Own<RequestHook> inner)107     RequestImpl(kj::Own<ReconnectHook> parent, kj::Own<RequestHook> inner)
108         : parent(kj::mv(parent)), inner(kj::mv(inner)) {}
109 
send()110     RemotePromise<AnyPointer> send() override {
111       auto result = inner->send();
112       parent->wrap(result);
113       return result;
114     }
115 
sendStreaming()116     kj::Promise<void> sendStreaming() override {
117       auto result = inner->sendStreaming();
118       parent->wrap(result);
119       return result;
120     }
121 
getBrand()122     const void* getBrand() override {
123       return nullptr;
124     }
125 
126   private:
127     kj::Own<ReconnectHook> parent;
128     kj::Own<RequestHook> inner;
129   };
130 };
131 
132 }  // namespace
133 
autoReconnect(kj::Function<Capability::Client ()> connect)134 Capability::Client autoReconnect(kj::Function<Capability::Client()> connect) {
135   return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect)));
136 }
137 
lazyAutoReconnect(kj::Function<Capability::Client ()> connect)138 Capability::Client lazyAutoReconnect(kj::Function<Capability::Client()> connect) {
139   return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect), true));
140 }
141 }  // namespace capnp
142