1 // Copyright (c) 2020 Cloudflare, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21
22 #include "reconnect.h"
23
24 namespace capnp {
25
26 namespace {
27
28 class ReconnectHook final: public ClientHook, public kj::Refcounted {
29 public:
ReconnectHook(kj::Function<Capability::Client ()> connectParam,bool lazy=false)30 ReconnectHook(kj::Function<Capability::Client()> connectParam, bool lazy = false)
31 : connect(kj::mv(connectParam)),
32 current(lazy ? kj::Maybe<kj::Own<ClientHook>>() : ClientHook::from(connect())) {}
33
newCall(uint64_t interfaceId,uint16_t methodId,kj::Maybe<MessageSize> sizeHint)34 Request<AnyPointer, AnyPointer> newCall(
35 uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
36 auto result = getCurrent().newCall(interfaceId, methodId, sizeHint);
37 AnyPointer::Builder builder = result;
38 auto hook = kj::heap<RequestImpl>(kj::addRef(*this), RequestHook::from(kj::mv(result)));
39 return { builder, kj::mv(hook) };
40 }
41
call(uint64_t interfaceId,uint16_t methodId,kj::Own<CallContextHook> && context)42 VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
43 kj::Own<CallContextHook>&& context) override {
44 auto result = getCurrent().call(interfaceId, methodId, kj::mv(context));
45 wrap(result.promise);
46 return result;
47 }
48
getResolved()49 kj::Maybe<ClientHook&> getResolved() override {
50 // We can't let people resolve to the underlying capability because then we wouldn't be able
51 // to redirect them later.
52 return nullptr;
53 }
54
whenMoreResolved()55 kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
56 return nullptr;
57 }
58
addRef()59 kj::Own<ClientHook> addRef() override {
60 return kj::addRef(*this);
61 }
62
getBrand()63 const void* getBrand() override {
64 return nullptr;
65 }
66
getFd()67 kj::Maybe<int> getFd() override {
68 // It's not safe to return current->getFd() because normally callers wouldn't expect the FD to
69 // change or go away over time, but this one could whenever we reconnect. If there's a use
70 // case for being able to access the FD here, we'll need a different interface to do it.
71 return nullptr;
72 }
73
74 private:
75 kj::Function<Capability::Client()> connect;
76 kj::Maybe<kj::Own<ClientHook>> current;
77 uint generation = 0;
78
79 template <typename T>
wrap(kj::Promise<T> & promise)80 void wrap(kj::Promise<T>& promise) {
81 promise = promise.catch_(
82 [self = kj::addRef(*this), startGeneration = generation]
83 (kj::Exception&& exception) mutable -> kj::Promise<T> {
84 if (exception.getType() == kj::Exception::Type::DISCONNECTED &&
85 self->generation == startGeneration) {
86 self->generation++;
87 KJ_IF_MAYBE(e2, kj::runCatchingExceptions([&]() {
88 self->current = ClientHook::from(self->connect());
89 })) {
90 self->current = newBrokenCap(kj::mv(*e2));
91 }
92 }
93 return kj::mv(exception);
94 });
95 }
96
getCurrent()97 ClientHook& getCurrent() {
98 KJ_IF_MAYBE(c, current) {
99 return **c;
100 } else {
101 return *current.emplace(ClientHook::from(connect()));
102 }
103 }
104
105 class RequestImpl final: public RequestHook {
106 public:
RequestImpl(kj::Own<ReconnectHook> parent,kj::Own<RequestHook> inner)107 RequestImpl(kj::Own<ReconnectHook> parent, kj::Own<RequestHook> inner)
108 : parent(kj::mv(parent)), inner(kj::mv(inner)) {}
109
send()110 RemotePromise<AnyPointer> send() override {
111 auto result = inner->send();
112 parent->wrap(result);
113 return result;
114 }
115
sendStreaming()116 kj::Promise<void> sendStreaming() override {
117 auto result = inner->sendStreaming();
118 parent->wrap(result);
119 return result;
120 }
121
getBrand()122 const void* getBrand() override {
123 return nullptr;
124 }
125
126 private:
127 kj::Own<ReconnectHook> parent;
128 kj::Own<RequestHook> inner;
129 };
130 };
131
132 } // namespace
133
autoReconnect(kj::Function<Capability::Client ()> connect)134 Capability::Client autoReconnect(kj::Function<Capability::Client()> connect) {
135 return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect)));
136 }
137
lazyAutoReconnect(kj::Function<Capability::Client ()> connect)138 Capability::Client lazyAutoReconnect(kj::Function<Capability::Client()> connect) {
139 return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect), true));
140 }
141 } // namespace capnp
142