1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21
22 #include "rpc-twoparty.h"
23 #include "serialize-async.h"
24 #include <kj/debug.h>
25 #include <kj/io.h>
26
27 namespace capnp {
28
TwoPartyVatNetwork(kj::OneOf<MessageStream *,kj::Own<MessageStream>> && stream,uint maxFdsPerMessage,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)29 TwoPartyVatNetwork::TwoPartyVatNetwork(
30 kj::OneOf<MessageStream*, kj::Own<MessageStream>>&& stream,
31 uint maxFdsPerMessage,
32 rpc::twoparty::Side side,
33 ReaderOptions receiveOptions,
34 const kj::MonotonicClock& clock)
35
36 : stream(kj::mv(stream)),
37 maxFdsPerMessage(maxFdsPerMessage),
38 side(side),
39 peerVatId(4),
40 receiveOptions(receiveOptions),
41 previousWrite(kj::READY_NOW),
42 clock(clock),
43 currentOutgoingMessageSendTime(clock.now()) {
44 peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
45 side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
46 : rpc::twoparty::Side::CLIENT);
47
48 auto paf = kj::newPromiseAndFulfiller<void>();
49 disconnectPromise = paf.promise.fork();
50 disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
51 }
52
TwoPartyVatNetwork(capnp::MessageStream & stream,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)53 TwoPartyVatNetwork::TwoPartyVatNetwork(capnp::MessageStream& stream,
54 rpc::twoparty::Side side, ReaderOptions receiveOptions,
55 const kj::MonotonicClock& clock)
56 : TwoPartyVatNetwork(stream, 0, side, receiveOptions, clock) {}
57
TwoPartyVatNetwork(capnp::MessageStream & stream,uint maxFdsPerMessage,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)58 TwoPartyVatNetwork::TwoPartyVatNetwork(
59 capnp::MessageStream& stream,
60 uint maxFdsPerMessage,
61 rpc::twoparty::Side side,
62 ReaderOptions receiveOptions,
63 const kj::MonotonicClock& clock)
64 : TwoPartyVatNetwork(&stream, maxFdsPerMessage, side, receiveOptions, clock) {}
65
TwoPartyVatNetwork(kj::AsyncIoStream & stream,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)66 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
67 ReaderOptions receiveOptions,
68 const kj::MonotonicClock& clock)
69 : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncIoMessageStream>(stream)),
70 0, side, receiveOptions, clock) {}
71
TwoPartyVatNetwork(kj::AsyncCapabilityStream & stream,uint maxFdsPerMessage,rpc::twoparty::Side side,ReaderOptions receiveOptions,const kj::MonotonicClock & clock)72 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage,
73 rpc::twoparty::Side side, ReaderOptions receiveOptions,
74 const kj::MonotonicClock& clock)
75 : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncCapabilityMessageStream>(stream)),
76 maxFdsPerMessage, side, receiveOptions, clock) {}
77
getStream()78 MessageStream& TwoPartyVatNetwork::getStream() {
79 KJ_SWITCH_ONEOF(stream) {
80 KJ_CASE_ONEOF(s, MessageStream*) {
81 return *s;
82 }
83 KJ_CASE_ONEOF(s, kj::Own<MessageStream>) {
84 return *s;
85 }
86 }
87 KJ_UNREACHABLE;
88 }
89
disposeImpl(void * pointer) const90 void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
91 if (--refcount == 0) {
92 fulfiller->fulfill();
93 }
94 }
95
asConnection()96 kj::Own<TwoPartyVatNetworkBase::Connection> TwoPartyVatNetwork::asConnection() {
97 ++disconnectFulfiller.refcount;
98 return kj::Own<TwoPartyVatNetworkBase::Connection>(this, disconnectFulfiller);
99 }
100
connect(rpc::twoparty::VatId::Reader ref)101 kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connect(
102 rpc::twoparty::VatId::Reader ref) {
103 if (ref.getSide() == side) {
104 return nullptr;
105 } else {
106 return asConnection();
107 }
108 }
109
accept()110 kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::accept() {
111 if (side == rpc::twoparty::Side::SERVER && !accepted) {
112 accepted = true;
113 return asConnection();
114 } else {
115 // Create a promise that will never be fulfilled.
116 auto paf = kj::newPromiseAndFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>();
117 acceptFulfiller = kj::mv(paf.fulfiller);
118 return kj::mv(paf.promise);
119 }
120 }
121
122 class TwoPartyVatNetwork::OutgoingMessageImpl final
123 : public OutgoingRpcMessage, public kj::Refcounted {
124 public:
OutgoingMessageImpl(TwoPartyVatNetwork & network,uint firstSegmentWordSize)125 OutgoingMessageImpl(TwoPartyVatNetwork& network, uint firstSegmentWordSize)
126 : network(network),
127 message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {}
128
getBody()129 AnyPointer::Builder getBody() override {
130 return message.getRoot<AnyPointer>();
131 }
132
setFds(kj::Array<int> fds)133 void setFds(kj::Array<int> fds) override {
134 if (network.maxFdsPerMessage > 0) {
135 this->fds = kj::mv(fds);
136 }
137 }
138
send()139 void send() override {
140 size_t size = 0;
141 for (auto& segment: message.getSegmentsForOutput()) {
142 size += segment.size();
143 }
144 KJ_REQUIRE(size < network.receiveOptions.traversalLimitInWords, size,
145 "Trying to send Cap'n Proto message larger than our single-message size limit. The "
146 "other side probably won't accept it (assuming its traversalLimitInWords matches "
147 "ours) and would abort the connection, so I won't send it.") {
148 return;
149 }
150
151 network.currentQueueSize += size * sizeof(capnp::word);
152 ++network.currentQueueCount;
153 auto deferredSizeUpdate = kj::defer([&network = network, size]() mutable {
154 network.currentQueueSize -= size * sizeof(capnp::word);
155 --network.currentQueueCount;
156 });
157
158 auto sendTime = network.clock.now();
159 network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down")
160 .then([this, sendTime]() {
161 return kj::evalNow([&]() {
162 network.currentOutgoingMessageSendTime = sendTime;
163 return network.getStream().writeMessage(fds, message);
164 }).catch_([this](kj::Exception&& e) {
165 // Since no one checks write failures, we need to propagate them into read failures,
166 // otherwise we might get stuck sending all messages into a black hole and wondering why
167 // the peer never replies.
168 network.readCancelReason = kj::cp(e);
169 if (!network.readCanceler.isEmpty()) {
170 network.readCanceler.cancel(kj::cp(e));
171 }
172 kj::throwRecoverableException(kj::mv(e));
173 });
174 }).attach(kj::addRef(*this), kj::mv(deferredSizeUpdate))
175 // Note that it's important that the eagerlyEvaluate() come *after* the attach() because
176 // otherwise the message (and any capabilities in it) will not be released until a new
177 // message is written! (Kenton once spent all afternoon tracking this down...)
178 .eagerlyEvaluate(nullptr);
179 }
180
sizeInWords()181 size_t sizeInWords() override {
182 return message.sizeInWords();
183 }
184
185 private:
186 TwoPartyVatNetwork& network;
187 MallocMessageBuilder message;
188 kj::Array<int> fds;
189 };
190
getOutgoingMessageWaitTime()191 kj::Duration TwoPartyVatNetwork::getOutgoingMessageWaitTime() {
192 if (currentQueueCount > 0) {
193 return clock.now() - currentOutgoingMessageSendTime;
194 } else {
195 return 0 * kj::SECONDS;
196 }
197 }
198
199 class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
200 public:
IncomingMessageImpl(kj::Own<MessageReader> message)201 IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
202
IncomingMessageImpl(MessageReaderAndFds init,kj::Array<kj::AutoCloseFd> fdSpace)203 IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
204 : message(kj::mv(init.reader)),
205 fdSpace(kj::mv(fdSpace)),
206 fds(init.fds) {
207 KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
208 }
209
getBody()210 AnyPointer::Reader getBody() override {
211 return message->getRoot<AnyPointer>();
212 }
213
getAttachedFds()214 kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
215 return fds;
216 }
217
sizeInWords()218 size_t sizeInWords() override {
219 return message->sizeInWords();
220 }
221
222 private:
223 kj::Own<MessageReader> message;
224 kj::Array<kj::AutoCloseFd> fdSpace;
225 kj::ArrayPtr<kj::AutoCloseFd> fds;
226 };
227
newStream()228 kj::Own<RpcFlowController> TwoPartyVatNetwork::newStream() {
229 return RpcFlowController::newVariableWindowController(*this);
230 }
231
getWindow()232 size_t TwoPartyVatNetwork::getWindow() {
233 // The socket's send buffer size -- as returned by getsockopt(SO_SNDBUF) -- tells us how much
234 // data the kernel itself is willing to buffer. The kernel will increase the send buffer size if
235 // needed to fill the connection's congestion window. So we can cheat and use it as our stream
236 // window, too, to make sure we saturate said congestion window.
237 //
238 // TODO(perf): Unfortunately, this hack breaks down in the presence of proxying. What we really
239 // want is the window all the way to the endpoint, which could cross multiple connections. The
240 // first-hop window could be either too big or too small: it's too big if the first hop has
241 // much higher bandwidth than the full path (causing buffering at the bottleneck), and it's
242 // too small if the first hop has much lower latency than the full path (causing not enough
243 // data to be sent to saturate the connection). To handle this, we could either:
244 // 1. Have proxies be aware of streaming, by flagging streaming calls in the RPC protocol. The
245 // proxies would then handle backpressure at each hop. This seems simple to implement but
246 // requires base RPC protocol changes and might require thinking carefully about e-ordering
247 // implications. Also, it only fixes underutilization; it does not fix buffer bloat.
248 // 2. Do our own BBR-like computation, where the client measures the end-to-end latency and
249 // bandwidth based on the observed sends and returns, and then compute the window based on
250 // that. This seems complicated, but avoids the need for any changes to the RPC protocol.
251 // In theory it solves both underutilization and buffer bloat. Note that this approach would
252 // require the RPC system to use a clock, which feels dirty and adds non-determinism.
253
254 if (solSndbufUnimplemented) {
255 return RpcFlowController::DEFAULT_WINDOW_SIZE;
256 } else {
257 KJ_IF_MAYBE(bufSize, getStream().getSendBufferSize()) {
258 return *bufSize;
259 } else {
260 solSndbufUnimplemented = true;
261 return RpcFlowController::DEFAULT_WINDOW_SIZE;
262 }
263 }
264 }
265
getPeerVatId()266 rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
267 return peerVatId.getRoot<rpc::twoparty::VatId>();
268 }
269
newOutgoingMessage(uint firstSegmentWordSize)270 kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSegmentWordSize) {
271 return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
272 }
273
receiveIncomingMessage()274 kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
275 return kj::evalLater([this]() -> kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> {
276 KJ_IF_MAYBE(e, readCancelReason) {
277 // A previous write failed; propagate the failure to reads, too.
278 return kj::cp(*e);
279 }
280
281 kj::Array<kj::AutoCloseFd> fdSpace = nullptr;
282 if(maxFdsPerMessage > 0) {
283 fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMessage);
284 }
285 auto promise = readCanceler.wrap(getStream().tryReadMessage(fdSpace, receiveOptions));
286 return promise.then([fdSpace = kj::mv(fdSpace)]
287 (kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
288 -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
289 KJ_IF_MAYBE(m, messageAndFds) {
290 if (m->fds.size() > 0) {
291 return kj::Own<IncomingRpcMessage>(
292 kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
293 } else {
294 return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
295 }
296 } else {
297 return nullptr;
298 }
299 });
300 });
301 }
302
shutdown()303 kj::Promise<void> TwoPartyVatNetwork::shutdown() {
304 kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
305 return getStream().end();
306 });
307 previousWrite = nullptr;
308 return kj::mv(result);
309 }
310
311 // =======================================================================================
312
TwoPartyServer(Capability::Client bootstrapInterface)313 TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface)
314 : bootstrapInterface(kj::mv(bootstrapInterface)), tasks(*this) {}
315
316 struct TwoPartyServer::AcceptedConnection {
317 kj::Own<kj::AsyncIoStream> connection;
318 TwoPartyVatNetwork network;
319 RpcSystem<rpc::twoparty::VatId> rpcSystem;
320
AcceptedConnectioncapnp::TwoPartyServer::AcceptedConnection321 explicit AcceptedConnection(Capability::Client bootstrapInterface,
322 kj::Own<kj::AsyncIoStream>&& connectionParam)
323 : connection(kj::mv(connectionParam)),
324 network(*connection, rpc::twoparty::Side::SERVER),
325 rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
326
AcceptedConnectioncapnp::TwoPartyServer::AcceptedConnection327 explicit AcceptedConnection(Capability::Client bootstrapInterface,
328 kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
329 uint maxFdsPerMessage)
330 : connection(kj::mv(connectionParam)),
331 network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
332 maxFdsPerMessage, rpc::twoparty::Side::SERVER),
333 rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
334 };
335
accept(kj::Own<kj::AsyncIoStream> && connection)336 void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
337 auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface, kj::mv(connection));
338
339 // Run the connection until disconnect.
340 auto promise = connectionState->network.onDisconnect();
341 tasks.add(promise.attach(kj::mv(connectionState)));
342 }
343
accept(kj::Own<kj::AsyncCapabilityStream> && connection,uint maxFdsPerMessage)344 void TwoPartyServer::accept(
345 kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMessage) {
346 auto connectionState = kj::heap<AcceptedConnection>(
347 bootstrapInterface, kj::mv(connection), maxFdsPerMessage);
348
349 // Run the connection until disconnect.
350 auto promise = connectionState->network.onDisconnect();
351 tasks.add(promise.attach(kj::mv(connectionState)));
352 }
353
accept(kj::AsyncIoStream & connection)354 kj::Promise<void> TwoPartyServer::accept(kj::AsyncIoStream& connection) {
355 auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface,
356 kj::Own<kj::AsyncIoStream>(&connection, kj::NullDisposer::instance));
357
358 // Run the connection until disconnect.
359 auto promise = connectionState->network.onDisconnect();
360 return promise.attach(kj::mv(connectionState));
361 }
362
accept(kj::AsyncCapabilityStream & connection,uint maxFdsPerMessage)363 kj::Promise<void> TwoPartyServer::accept(
364 kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) {
365 auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface,
366 kj::Own<kj::AsyncCapabilityStream>(&connection, kj::NullDisposer::instance),
367 maxFdsPerMessage);
368
369 // Run the connection until disconnect.
370 auto promise = connectionState->network.onDisconnect();
371 return promise.attach(kj::mv(connectionState));
372 }
373
listen(kj::ConnectionReceiver & listener)374 kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
375 return listener.accept()
376 .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
377 accept(kj::mv(connection));
378 return listen(listener);
379 });
380 }
381
listenCapStreamReceiver(kj::ConnectionReceiver & listener,uint maxFdsPerMessage)382 kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
383 kj::ConnectionReceiver& listener, uint maxFdsPerMessage) {
384 return listener.accept()
385 .then([this,&listener,maxFdsPerMessage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
386 accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMessage);
387 return listenCapStreamReceiver(listener, maxFdsPerMessage);
388 });
389 }
390
// Error handler for the `tasks` set. The server registers itself as the handler in its
// constructor, and the owning accept() overloads add their per-connection tasks to that set.
// A failed task is only logged at ERROR severity; nothing is rethrown.
void TwoPartyServer::taskFailed(kj::Exception&& exception) {
  KJ_LOG(ERROR, exception);
}
394
TwoPartyClient(kj::AsyncIoStream & connection)395 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
396 : network(connection, rpc::twoparty::Side::CLIENT),
397 rpcSystem(makeRpcClient(network)) {}
398
399
TwoPartyClient(kj::AsyncCapabilityStream & connection,uint maxFdsPerMessage)400 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage)
401 : network(connection, maxFdsPerMessage, rpc::twoparty::Side::CLIENT),
402 rpcSystem(makeRpcClient(network)) {}
403
TwoPartyClient(kj::AsyncIoStream & connection,Capability::Client bootstrapInterface,rpc::twoparty::Side side)404 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
405 Capability::Client bootstrapInterface,
406 rpc::twoparty::Side side)
407 : network(connection, side),
408 rpcSystem(network, bootstrapInterface) {}
409
TwoPartyClient(kj::AsyncCapabilityStream & connection,uint maxFdsPerMessage,Capability::Client bootstrapInterface,rpc::twoparty::Side side)410 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage,
411 Capability::Client bootstrapInterface,
412 rpc::twoparty::Side side)
413 : network(connection, maxFdsPerMessage, side),
414 rpcSystem(network, bootstrapInterface) {}
415
bootstrap()416 Capability::Client TwoPartyClient::bootstrap() {
417 capnp::word scratch[4];
418 memset(&scratch, 0, sizeof(scratch));
419 capnp::MallocMessageBuilder message(scratch);
420 auto vatId = message.getRoot<rpc::twoparty::VatId>();
421 vatId.setSide(network.getSide() == rpc::twoparty::Side::CLIENT
422 ? rpc::twoparty::Side::SERVER
423 : rpc::twoparty::Side::CLIENT);
424 return rpcSystem.bootstrap(vatId);
425 }
426
setTraceEncoder(kj::Function<kj::String (const kj::Exception &)> func)427 void TwoPartyClient::setTraceEncoder(kj::Function<kj::String(const kj::Exception&)> func) {
428 rpcSystem.setTraceEncoder(kj::mv(func));
429 }
430
431 } // namespace capnp
432