1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 
22 #include "rpc-twoparty.h"
23 #include "serialize-async.h"
24 #include <kj/debug.h>
25 #include <kj/io.h>
26 
27 namespace capnp {
28 
TwoPartyVatNetwork(kj::OneOf<MessageStream *,kj::Own<MessageStream>> && stream,uint maxFdsPerMessage,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)29 TwoPartyVatNetwork::TwoPartyVatNetwork(
30     kj::OneOf<MessageStream*, kj::Own<MessageStream>>&& stream,
31     uint maxFdsPerMessage,
32     rpc::twoparty::Side side,
33     ReaderOptions receiveOptions,
34     const kj::MonotonicClock& clock)
35 
36     : stream(kj::mv(stream)),
37       maxFdsPerMessage(maxFdsPerMessage),
38       side(side),
39       peerVatId(4),
40       receiveOptions(receiveOptions),
41       previousWrite(kj::READY_NOW),
42       clock(clock),
43       currentOutgoingMessageSendTime(clock.now()) {
44   peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
45       side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
46                                           : rpc::twoparty::Side::CLIENT);
47 
48   auto paf = kj::newPromiseAndFulfiller<void>();
49   disconnectPromise = paf.promise.fork();
50   disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
51 }
52 
TwoPartyVatNetwork(capnp::MessageStream & stream,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)53 TwoPartyVatNetwork::TwoPartyVatNetwork(capnp::MessageStream& stream,
54                    rpc::twoparty::Side side, ReaderOptions receiveOptions,
55                    const kj::MonotonicClock& clock)
56   : TwoPartyVatNetwork(stream, 0, side, receiveOptions, clock) {}
57 
TwoPartyVatNetwork(capnp::MessageStream & stream,uint maxFdsPerMessage,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)58 TwoPartyVatNetwork::TwoPartyVatNetwork(
59     capnp::MessageStream& stream,
60     uint maxFdsPerMessage,
61     rpc::twoparty::Side side,
62     ReaderOptions receiveOptions,
63     const kj::MonotonicClock& clock)
64     : TwoPartyVatNetwork(&stream, maxFdsPerMessage, side, receiveOptions, clock) {}
65 
TwoPartyVatNetwork(kj::AsyncIoStream & stream,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)66 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
67                                        ReaderOptions receiveOptions,
68                                        const kj::MonotonicClock& clock)
69     : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncIoMessageStream>(stream)),
70                          0, side, receiveOptions, clock) {}
71 
TwoPartyVatNetwork(kj::AsyncCapabilityStream & stream,uint maxFdsPerMessage,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)72 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage,
73                                        rpc::twoparty::Side side, ReaderOptions receiveOptions,
74                                        const kj::MonotonicClock& clock)
75     : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncCapabilityMessageStream>(stream)),
76                          maxFdsPerMessage, side, receiveOptions, clock) {}
77 
getStream()78 MessageStream& TwoPartyVatNetwork::getStream() {
79   KJ_SWITCH_ONEOF(stream) {
80     KJ_CASE_ONEOF(s, MessageStream*) {
81       return *s;
82     }
83     KJ_CASE_ONEOF(s, kj::Own<MessageStream>) {
84       return *s;
85     }
86   }
87   KJ_UNREACHABLE;
88 }
89 
disposeImpl(void * pointer) const90 void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
91   if (--refcount == 0) {
92     fulfiller->fulfill();
93   }
94 }
95 
asConnection()96 kj::Own<TwoPartyVatNetworkBase::Connection> TwoPartyVatNetwork::asConnection() {
97   ++disconnectFulfiller.refcount;
98   return kj::Own<TwoPartyVatNetworkBase::Connection>(this, disconnectFulfiller);
99 }
100 
connect(rpc::twoparty::VatId::Reader ref)101 kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connect(
102     rpc::twoparty::VatId::Reader ref) {
103   if (ref.getSide() == side) {
104     return nullptr;
105   } else {
106     return asConnection();
107   }
108 }
109 
accept()110 kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::accept() {
111   if (side == rpc::twoparty::Side::SERVER && !accepted) {
112     accepted = true;
113     return asConnection();
114   } else {
115     // Create a promise that will never be fulfilled.
116     auto paf = kj::newPromiseAndFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>();
117     acceptFulfiller = kj::mv(paf.fulfiller);
118     return kj::mv(paf.promise);
119   }
120 }
121 
122 class TwoPartyVatNetwork::OutgoingMessageImpl final
123     : public OutgoingRpcMessage, public kj::Refcounted {
124 public:
OutgoingMessageImpl(TwoPartyVatNetwork & network,uint firstSegmentWordSize)125   OutgoingMessageImpl(TwoPartyVatNetwork& network, uint firstSegmentWordSize)
126       : network(network),
127         message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {}
128 
getBody()129   AnyPointer::Builder getBody() override {
130     return message.getRoot<AnyPointer>();
131   }
132 
setFds(kj::Array<int> fds)133   void setFds(kj::Array<int> fds) override {
134     if (network.maxFdsPerMessage > 0) {
135       this->fds = kj::mv(fds);
136     }
137   }
138 
send()139   void send() override {
140     size_t size = 0;
141     for (auto& segment: message.getSegmentsForOutput()) {
142       size += segment.size();
143     }
144     KJ_REQUIRE(size < network.receiveOptions.traversalLimitInWords, size,
145                "Trying to send Cap'n Proto message larger than our single-message size limit. The "
146                "other side probably won't accept it (assuming its traversalLimitInWords matches "
147                "ours) and would abort the connection, so I won't send it.") {
148       return;
149     }
150 
151     network.currentQueueSize += size * sizeof(capnp::word);
152     ++network.currentQueueCount;
153     auto deferredSizeUpdate = kj::defer([&network = network, size]() mutable {
154       network.currentQueueSize -= size * sizeof(capnp::word);
155       --network.currentQueueCount;
156     });
157 
158     auto sendTime = network.clock.now();
159     network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down")
160         .then([this, sendTime]() {
161       return kj::evalNow([&]() {
162         network.currentOutgoingMessageSendTime = sendTime;
163         return network.getStream().writeMessage(fds, message);
164       }).catch_([this](kj::Exception&& e) {
165         // Since no one checks write failures, we need to propagate them into read failures,
166         // otherwise we might get stuck sending all messages into a black hole and wondering why
167         // the peer never replies.
168         network.readCancelReason = kj::cp(e);
169         if (!network.readCanceler.isEmpty()) {
170           network.readCanceler.cancel(kj::cp(e));
171         }
172         kj::throwRecoverableException(kj::mv(e));
173       });
174     }).attach(kj::addRef(*this), kj::mv(deferredSizeUpdate))
175       // Note that it's important that the eagerlyEvaluate() come *after* the attach() because
176       // otherwise the message (and any capabilities in it) will not be released until a new
177       // message is written! (Kenton once spent all afternoon tracking this down...)
178       .eagerlyEvaluate(nullptr);
179   }
180 
sizeInWords()181   size_t sizeInWords() override {
182     return message.sizeInWords();
183   }
184 
185 private:
186   TwoPartyVatNetwork& network;
187   MallocMessageBuilder message;
188   kj::Array<int> fds;
189 };
190 
getOutgoingMessageWaitTime()191 kj::Duration TwoPartyVatNetwork::getOutgoingMessageWaitTime() {
192   if (currentQueueCount > 0) {
193     return clock.now() - currentOutgoingMessageSendTime;
194   } else {
195     return 0 * kj::SECONDS;
196   }
197 }
198 
199 class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
200 public:
IncomingMessageImpl(kj::Own<MessageReader> message)201   IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
202 
IncomingMessageImpl(MessageReaderAndFds init,kj::Array<kj::AutoCloseFd> fdSpace)203   IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
204       : message(kj::mv(init.reader)),
205         fdSpace(kj::mv(fdSpace)),
206         fds(init.fds) {
207     KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
208   }
209 
getBody()210   AnyPointer::Reader getBody() override {
211     return message->getRoot<AnyPointer>();
212   }
213 
getAttachedFds()214   kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
215     return fds;
216   }
217 
sizeInWords()218   size_t sizeInWords() override {
219     return message->sizeInWords();
220   }
221 
222 private:
223   kj::Own<MessageReader> message;
224   kj::Array<kj::AutoCloseFd> fdSpace;
225   kj::ArrayPtr<kj::AutoCloseFd> fds;
226 };
227 
newStream()228 kj::Own<RpcFlowController> TwoPartyVatNetwork::newStream() {
229   return RpcFlowController::newVariableWindowController(*this);
230 }
231 
getWindow()232 size_t TwoPartyVatNetwork::getWindow() {
233   // The socket's send buffer size -- as returned by getsockopt(SO_SNDBUF) -- tells us how much
234   // data the kernel itself is willing to buffer. The kernel will increase the send buffer size if
235   // needed to fill the connection's congestion window. So we can cheat and use it as our stream
236   // window, too, to make sure we saturate said congestion window.
237   //
238   // TODO(perf): Unfortunately, this hack breaks down in the presence of proxying. What we really
239   //   want is the window all the way to the endpoint, which could cross multiple connections. The
240   //   first-hop window could be either too big or too small: it's too big if the first hop has
241   //   much higher bandwidth than the full path (causing buffering at the bottleneck), and it's
242   //   too small if the first hop has much lower latency than the full path (causing not enough
243   //   data to be sent to saturate the connection). To handle this, we could either:
244   //   1. Have proxies be aware of streaming, by flagging streaming calls in the RPC protocol. The
245   //      proxies would then handle backpressure at each hop. This seems simple to implement but
246   //      requires base RPC protocol changes and might require thinking carefully about e-ordering
247   //      implications. Also, it only fixes underutilization; it does not fix buffer bloat.
248   //   2. Do our own BBR-like computation, where the client measures the end-to-end latency and
249   //      bandwidth based on the observed sends and returns, and then compute the window based on
250   //      that. This seems complicated, but avoids the need for any changes to the RPC protocol.
251   //      In theory it solves both underutilization and buffer bloat. Note that this approach would
252   //      require the RPC system to use a clock, which feels dirty and adds non-determinism.
253 
254   if (solSndbufUnimplemented) {
255     return RpcFlowController::DEFAULT_WINDOW_SIZE;
256   } else {
257     KJ_IF_MAYBE(bufSize, getStream().getSendBufferSize()) {
258       return *bufSize;
259     } else {
260       solSndbufUnimplemented = true;
261       return RpcFlowController::DEFAULT_WINDOW_SIZE;
262     }
263   }
264 }
265 
getPeerVatId()266 rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
267   return peerVatId.getRoot<rpc::twoparty::VatId>();
268 }
269 
newOutgoingMessage(uint firstSegmentWordSize)270 kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSegmentWordSize) {
271   return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
272 }
273 
receiveIncomingMessage()274 kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
275   return kj::evalLater([this]() -> kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> {
276     KJ_IF_MAYBE(e, readCancelReason) {
277       // A previous write failed; propagate the failure to reads, too.
278       return kj::cp(*e);
279     }
280 
281     kj::Array<kj::AutoCloseFd> fdSpace = nullptr;
282     if(maxFdsPerMessage > 0) {
283       fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMessage);
284     }
285     auto promise = readCanceler.wrap(getStream().tryReadMessage(fdSpace, receiveOptions));
286     return promise.then([fdSpace = kj::mv(fdSpace)]
287                         (kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
288                       -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
289       KJ_IF_MAYBE(m, messageAndFds) {
290         if (m->fds.size() > 0) {
291           return kj::Own<IncomingRpcMessage>(
292               kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
293         } else {
294           return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
295         }
296       } else {
297         return nullptr;
298       }
299     });
300   });
301 }
302 
shutdown()303 kj::Promise<void> TwoPartyVatNetwork::shutdown() {
304   kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
305     return getStream().end();
306   });
307   previousWrite = nullptr;
308   return kj::mv(result);
309 }
310 
311 // =======================================================================================
312 
TwoPartyServer(Capability::Client bootstrapInterface)313 TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface)
314     : bootstrapInterface(kj::mv(bootstrapInterface)), tasks(*this) {}
315 
316 struct TwoPartyServer::AcceptedConnection {
317   kj::Own<kj::AsyncIoStream> connection;
318   TwoPartyVatNetwork network;
319   RpcSystem<rpc::twoparty::VatId> rpcSystem;
320 
AcceptedConnectioncapnp::TwoPartyServer::AcceptedConnection321   explicit AcceptedConnection(Capability::Client bootstrapInterface,
322                               kj::Own<kj::AsyncIoStream>&& connectionParam)
323       : connection(kj::mv(connectionParam)),
324         network(*connection, rpc::twoparty::Side::SERVER),
325         rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
326 
AcceptedConnectioncapnp::TwoPartyServer::AcceptedConnection327   explicit AcceptedConnection(Capability::Client bootstrapInterface,
328                               kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
329                               uint maxFdsPerMessage)
330       : connection(kj::mv(connectionParam)),
331         network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
332                 maxFdsPerMessage, rpc::twoparty::Side::SERVER),
333         rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
334 };
335 
accept(kj::Own<kj::AsyncIoStream> && connection)336 void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
337   auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface, kj::mv(connection));
338 
339   // Run the connection until disconnect.
340   auto promise = connectionState->network.onDisconnect();
341   tasks.add(promise.attach(kj::mv(connectionState)));
342 }
343 
accept(kj::Own<kj::AsyncCapabilityStream> && connection,uint maxFdsPerMessage)344 void TwoPartyServer::accept(
345     kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMessage) {
346   auto connectionState = kj::heap<AcceptedConnection>(
347       bootstrapInterface, kj::mv(connection), maxFdsPerMessage);
348 
349   // Run the connection until disconnect.
350   auto promise = connectionState->network.onDisconnect();
351   tasks.add(promise.attach(kj::mv(connectionState)));
352 }
353 
accept(kj::AsyncIoStream & connection)354 kj::Promise<void> TwoPartyServer::accept(kj::AsyncIoStream& connection) {
355   auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface,
356       kj::Own<kj::AsyncIoStream>(&connection, kj::NullDisposer::instance));
357 
358   // Run the connection until disconnect.
359   auto promise = connectionState->network.onDisconnect();
360   return promise.attach(kj::mv(connectionState));
361 }
362 
accept(kj::AsyncCapabilityStream & connection,uint maxFdsPerMessage)363 kj::Promise<void> TwoPartyServer::accept(
364     kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) {
365   auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface,
366       kj::Own<kj::AsyncCapabilityStream>(&connection, kj::NullDisposer::instance),
367       maxFdsPerMessage);
368 
369   // Run the connection until disconnect.
370   auto promise = connectionState->network.onDisconnect();
371   return promise.attach(kj::mv(connectionState));
372 }
373 
listen(kj::ConnectionReceiver & listener)374 kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
375   return listener.accept()
376       .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
377     accept(kj::mv(connection));
378     return listen(listener);
379   });
380 }
381 
listenCapStreamReceiver(kj::ConnectionReceiver & listener,uint maxFdsPerMessage)382 kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
383       kj::ConnectionReceiver& listener, uint maxFdsPerMessage) {
384   return listener.accept()
385       .then([this,&listener,maxFdsPerMessage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
386     accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMessage);
387     return listenCapStreamReceiver(listener, maxFdsPerMessage);
388   });
389 }
390 
taskFailed(kj::Exception && exception)391 void TwoPartyServer::taskFailed(kj::Exception&& exception) {
392   KJ_LOG(ERROR, exception);
393 }
394 
TwoPartyClient(kj::AsyncIoStream & connection)395 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
396     : network(connection, rpc::twoparty::Side::CLIENT),
397       rpcSystem(makeRpcClient(network)) {}
398 
399 
TwoPartyClient(kj::AsyncCapabilityStream & connection,uint maxFdsPerMessage)400 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage)
401     : network(connection, maxFdsPerMessage, rpc::twoparty::Side::CLIENT),
402       rpcSystem(makeRpcClient(network)) {}
403 
TwoPartyClient(kj::AsyncIoStream & connection,Capability::Client bootstrapInterface,rpc::twoparty::Side side)404 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
405                                Capability::Client bootstrapInterface,
406                                rpc::twoparty::Side side)
407     : network(connection, side),
408       rpcSystem(network, bootstrapInterface) {}
409 
TwoPartyClient(kj::AsyncCapabilityStream & connection,uint maxFdsPerMessage,Capability::Client bootstrapInterface,rpc::twoparty::Side side)410 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage,
411                                Capability::Client bootstrapInterface,
412                                rpc::twoparty::Side side)
413     : network(connection, maxFdsPerMessage, side),
414       rpcSystem(network, bootstrapInterface) {}
415 
bootstrap()416 Capability::Client TwoPartyClient::bootstrap() {
417   capnp::word scratch[4];
418   memset(&scratch, 0, sizeof(scratch));
419   capnp::MallocMessageBuilder message(scratch);
420   auto vatId = message.getRoot<rpc::twoparty::VatId>();
421   vatId.setSide(network.getSide() == rpc::twoparty::Side::CLIENT
422                 ? rpc::twoparty::Side::SERVER
423                 : rpc::twoparty::Side::CLIENT);
424   return rpcSystem.bootstrap(vatId);
425 }
426 
setTraceEncoder(kj::Function<kj::String (const kj::Exception &)> func)427 void TwoPartyClient::setTraceEncoder(kj::Function<kj::String(const kj::Exception&)> func) {
428   rpcSystem.setTraceEncoder(kj::mv(func));
429 }
430 
431 }  // namespace capnp
432