1 // Copyright (c) 2019 Cloudflare, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21
22 #include "byte-stream.h"
23 #include <kj/one-of.h>
24 #include <kj/debug.h>
25
26 namespace capnp {
27
constexpr uint MAX_BYTES_PER_WRITE = 64 * 1024;
// Upper bound, in bytes, on any single write issued by KjToCapnpStreamAdapter. Larger writes are
// split into chunks no bigger than this.
29
30 class ByteStreamFactory::StreamServerBase: public capnp::ByteStream::Server {
31 public:
32 virtual void returnStream(uint64_t written) = 0;
33 // Called after the StreamServerBase's internal kj::AsyncOutputStream has been borrowed, to
34 // indicate that the borrower is done.
35 //
36 // A stream becomes borrowed either when getShortestPath() returns a BorrowedStream, or when
37 // a SubstreamImpl is constructed wrapping an existing stream.
38
39 struct BorrowedStream {
40 // Represents permission to use the StreamServerBase's inner AsyncOutputStream directly, up
41 // to some limit of bytes written.
42
43 StreamServerBase& lender;
44 kj::AsyncOutputStream& stream;
45 uint64_t limit;
46 };
47
48 typedef kj::OneOf<kj::Promise<void>, capnp::ByteStream::Client*, BorrowedStream> ShortestPath;
49
50 virtual ShortestPath getShortestPath() = 0;
51 // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client
52 // actually points back to a StreamServerBase in the same process created by the same
53 // ByteStreamFactory. Returns the best shortened path to use, or a promise that resolves when the
54 // shortest path is known.
55
56 virtual void directEnd() = 0;
57 // Called by KjToCapnpStreamAdapter's destructor when it has determined that its inner
58 // ByteStream::Client actually points back to a StreamServerBase in the same process created by
59 // the same ByteStreamFactory. Since destruction of a KJ stream signals EOF, we need to propagate
60 // that by destroying our underlying stream.
61 // TODO(cleanup): When KJ streams evolve an end() method, this can go away.
62 };
63
class ByteStreamFactory::SubstreamImpl final: public StreamServerBase {
  // A ByteStream server created by getSubstream(). It writes directly into an output stream
  // borrowed from `parent` until `limit` bytes have been written. At that point it asks
  // `callback` (reachedLimit()) for the next stream, hands the borrowed stream back to `parent`,
  // and forwards all further calls to that next stream. If end() arrives first, it instead
  // reports the byte count through callback.ended() and hands the stream back.
  //
  // States (see `state` below):
  //   Streaming  -- writing directly into the borrowed stream.
  //   Borrowed   -- our stream is itself lent out (to a nested SubstreamImpl, or to a local
  //                 KjToCapnpStreamAdapter via getShortestPath()); write()/end()/getSubstream()
  //                 fail until returnStream() is called.
  //   Redirected -- the limit was reached; calls are forwarded to `replacement`.
  //   Ended      -- end() or directEnd() was called.
public:
  // `stream` is borrowed from `parent` until this substream calls parent.returnStream(). `paf`
  // only exists so the constructor can create the fulfiller/promise pair backing shortenPath();
  // callers are not expected to pass it.
  SubstreamImpl(ByteStreamFactory& factory,
                StreamServerBase& parent,
                capnp::ByteStream::Client ownParent,
                kj::AsyncOutputStream& stream,
                capnp::ByteStream::SubstreamCallback::Client callback,
                uint64_t limit,
                kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>())
      : factory(factory),
        state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback)}),
        limit(limit),
        resolveFulfiller(kj::mv(paf.fulfiller)),
        resolvePromise(paf.promise.fork()) {}

  // ---------------------------------------------------------------------------
  // implements StreamServerBase

  void returnStream(uint64_t written) override {
    // Whoever borrowed our stream is done and wrote `written` bytes. Count them against our
    // limit, go back to Streaming, and redirect if that exhausted the limit.
    completed += written;
    KJ_ASSERT(completed <= limit);
    auto borrowed = kj::mv(state.get<Borrowed>());
    state = kj::mv(borrowed.originalState);

    if (completed == limit) {
      limitReached();
    }
  }

  ShortestPath getShortestPath() override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        return &redirected.replacement;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while substream is active");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // `stream` binds to the parent's output stream itself (not to the member of
        // `streaming`), so it stays valid after `state` is reassigned below.
        auto& stream = streaming.stream;
        auto oldState = kj::mv(streaming);
        state = Borrowed { kj::mv(oldState) };
        // The borrower may write at most our remaining byte budget.
        return BorrowedStream { *this, stream, limit - completed };
      }
    }
    KJ_UNREACHABLE;
  }

  void directEnd() override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        // Ugh I guess we need to send a real end() request here.
        redirected.replacement.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
      }
      KJ_CASE_ONEOF(e, Ended) {
        // Already ended; nothing more to do.
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        // Our stream is lent out; nothing is done here.
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Same as end(), but fire-and-forget: errors from the ended() call are ignored.
        auto req = streaming.callback.endedRequest(MessageSize {4, 0});
        req.setByteCount(completed);
        req.send().detach([](kj::Exception&&){});
        streaming.parent.returnStream(completed);
        state = Ended();
      }
    }
  }

  // ---------------------------------------------------------------------------
  // implements ByteStream::Server RPC interface

  kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    // Resolves to the replacement stream once limitReached() has fulfilled `resolveFulfiller`.
    // Until then (and forever, if we end without reaching the limit), it stays pending.
    return resolvePromise.addBranch()
        .then([this]() -> Capability::Client {
      return state.get<Redirected>().replacement;
    });
  }

  kj::Promise<void> write(WriteContext context) override {
    auto params = context.getParams();
    auto data = params.getBytes();

    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        auto req = redirected.replacement.writeRequest(params.totalSize());
        req.setBytes(data);
        return req.send();
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        if (completed + data.size() < limit) {
          // Entirely within the limit (and doesn't reach it): write straight through.
          completed += data.size();
          return streaming.stream.write(data.begin(), data.size());
        } else {
          // This write passes the limit.
          // Write only the bytes that fit, then redirect and forward whatever is left over.
          // `completed` is updated in the continuation, once the partial write has finished.
          uint64_t remainder = limit - completed;
          auto leftover = data.slice(remainder, data.size());
          return streaming.stream.write(data.begin(), remainder)
              .then([this, leftover]() -> kj::Promise<void> {
            completed = limit;
            limitReached();

            if (leftover.size() > 0) {
              // Need to forward the leftover bytes to the next stream.
              auto req = state.get<Redirected>().replacement.writeRequest(
                  MessageSize { 4 + leftover.size() / sizeof(capnp::word), 0 });
              req.setBytes(leftover);
              return req.send();
            } else {
              return kj::READY_NOW;
            }
          });
        }
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> end(EndContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        return context.tailCall(redirected.replacement.endRequest(MessageSize {2,0}));
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Ended before the limit: report how many bytes were written, then hand the stream
        // back to the parent.
        auto req = streaming.callback.endedRequest(MessageSize {4, 0});
        req.setByteCount(completed);
        auto result = req.send().ignoreResult();
        streaming.parent.returnStream(completed);
        state = Ended();
        return result;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> getSubstream(GetSubstreamContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        auto params = context.getParams();
        auto req = redirected.replacement.getSubstreamRequest(params.totalSize());
        req.setCallback(params.getCallback());
        req.setLimit(params.getLimit());
        return context.tailCall(kj::mv(req));
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Nest a new substream on top of our borrowed stream; we stay Borrowed until the nested
        // substream calls our returnStream().
        // NOTE(review): the nested `limit` is not clamped to our remaining budget; if it exceeds
        // it, the KJ_ASSERT in returnStream() would fire. Confirm whether callers guarantee this.
        auto params = context.getParams();
        auto callback = params.getCallback();
        auto limit = params.getLimit();
        context.releaseParams();
        auto results = context.getResults(MessageSize { 2, 1 });
        results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>(
            factory, *this, thisCap(), streaming.stream, kj::mv(callback), kj::mv(limit))));
        state = Borrowed { kj::mv(streaming) };
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

private:
  ByteStreamFactory& factory;

  struct Streaming {
    StreamServerBase& parent;
    capnp::ByteStream::Client ownParent;
    // Never read in this class; presumably held to keep `parent` alive while we borrow its
    // stream -- confirm.
    kj::AsyncOutputStream& stream;
    capnp::ByteStream::SubstreamCallback::Client callback;
  };
  struct Borrowed {
    Streaming originalState;
    // Restored by returnStream().
  };
  struct Redirected {
    capnp::ByteStream::Client replacement;
    // The `next` stream obtained from callback.reachedLimit().
  };
  struct Ended {};

  kj::OneOf<Streaming, Borrowed, Redirected, Ended> state;

  uint64_t limit;
  // Maximum number of bytes this substream accepts before redirecting.
  uint64_t completed = 0;
  // Bytes written so far, whether by us or by borrowers (via returnStream()).

  kj::Own<kj::PromiseFulfiller<void>> resolveFulfiller;
  kj::ForkedPromise<void> resolvePromise;
  // Fulfilled by limitReached(); shortenPath() waits on it.

  void limitReached() {
    // Precondition: state is Streaming and `completed == limit`. Ask the callback for the next
    // stream, give our borrowed stream back to the parent, and switch to Redirected.
    auto& streaming = state.get<Streaming>();
    auto next = streaming.callback.reachedLimitRequest(capnp::MessageSize {2,0})
        .send().getNext();

    // Set the next stream as our replacement.
    streaming.parent.returnStream(limit);
    state = Redirected { kj::mv(next) };
    resolveFulfiller->fulfill();
  }
};
280
281 // =======================================================================================
282
class ByteStreamFactory::CapnpToKjStreamAdapter final: public StreamServerBase {
  // Implements Cap'n Proto ByteStream as a wrapper around a KJ stream.
  //
  // On construction, a PathProber pumps into the wrapped KJ stream to find out whether the data
  // ultimately flows into a KjToCapnpStreamAdapter (i.e. back out over Cap'n Proto). If it does,
  // this adapter is pointed at a substream of that outgoing ByteStream so the bytes skip the
  // local hop. Otherwise it falls back to writing into the KJ stream directly.
  //
  // `state` is one of:
  //   kj::Own<PathProber>             -- probing in progress; RPC calls wait on whenReady().
  //   kj::Own<kj::AsyncOutputStream>  -- writing directly into the KJ stream.
  //   capnp::ByteStream::Client       -- redirected to another ByteStream.
  //   Borrowed                        -- the KJ stream is lent to a SubstreamImpl or a local
  //                                      KjToCapnpStreamAdapter until returnStream().
  //   Ended                           -- ended (also used briefly as a placeholder inside
  //                                      PathProber::pumpToShorterPath()).

  class SubstreamCallbackImpl;

public:
  class PathProber;

  CapnpToKjStreamAdapter(ByteStreamFactory& factory,
                         kj::Own<kj::AsyncOutputStream> inner)
      : factory(factory),
        state(kj::heap<PathProber>(*this, kj::mv(inner))) {
    // Probing is started separately, only after `state` holds the prober, because the probe may
    // synchronously reach pumpToShorterPath(), which takes the prober back out of `state`.
    state.get<kj::Own<PathProber>>()->startProbing();
  }

  CapnpToKjStreamAdapter(ByteStreamFactory& factory,
                         kj::Own<PathProber> pathProber)
      : factory(factory),
        state(kj::mv(pathProber)) {
    // Used by SubstreamCallbackImpl::reachedLimit() to adopt an existing prober (and its KJ
    // stream) after a shortened path has reached its limit.
    state.get<kj::Own<PathProber>>()->setNewParent(*this);
  }

  // ---------------------------------------------------------------------------
  // implements StreamServerBase

  void returnStream(uint64_t written) override {
    // `written` is unused: this adapter lends its stream with no byte limit (see
    // getShortestPath()), so there is nothing to account for.
    auto stream = kj::mv(state.get<Borrowed>().stream);
    state = kj::mv(stream);
  }

  ShortestPath getShortestPath() override {
    // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client
    // actually points back to a CapnpToKjStreamAdapter in the same process. Returns the best
    // shortened path to use, or a promise that resolves when the shortest path is known.

    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        return prober->whenReady();
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        auto& streamRef = *kjStream;
        state = Borrowed { kj::mv(kjStream) };
        return StreamServerBase::BorrowedStream { *this, streamRef, kj::maxValue };
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        return &capnpStream;
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::Promise<void>(kj::READY_NOW);
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already ended") { break; }
        return kj::Promise<void>(kj::READY_NOW);
      }
    }
    KJ_UNREACHABLE;
  }

  void directEnd() override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Dropping the prober also drops the KJ stream it owns (and its probe task).
        state = Ended();
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // Dropping the KJ stream is how EOF is signalled to it.
        state = Ended();
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        // Ugh I guess we need to send a real end() request here.
        capnpStream.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        // Fine, ignore.
      }
      KJ_CASE_ONEOF(e, Ended) {
        // Fine, ignore.
      }
    }
  }

  // ---------------------------------------------------------------------------
  // PathProber

  class PathProber final: public kj::AsyncInputStream {
    // A stand-in input stream that is pumped into the wrapped KJ stream purely to discover
    // where the data would end up. It never produces any real bytes: reads report EOF.
  public:
    PathProber(CapnpToKjStreamAdapter& parent, kj::Own<kj::AsyncOutputStream> inner,
               kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>())
        : parent(parent), inner(kj::mv(inner)),
          readyPromise(paf.promise.fork()),
          readyFulfiller(kj::mv(paf.fulfiller)),
          task(nullptr) {}

    void startProbing() {
      task = probeForShorterPath();
    }

    void setNewParent(CapnpToKjStreamAdapter& newParent) {
      // Re-attach to a new adapter after a previous shortened path ended. A fresh ready
      // promise/fulfiller pair is created for the new parent to wait on.
      KJ_ASSERT(parent == nullptr);
      parent = newParent;
      auto paf = kj::newPromiseAndFulfiller<void>();
      readyPromise = paf.promise.fork();
      readyFulfiller = kj::mv(paf.fulfiller);
    }

    kj::Promise<void> whenReady() {
      // Resolves once the parent's `state` no longer needs to hold this prober for routing
      // decisions (or rejects if probing threw).
      return readyPromise.addBranch();
    }

    kj::Promise<uint64_t> pumpToShorterPath(capnp::ByteStream::Client target, uint64_t limit) {
      // If our probe succeeds in finding a KjToCapnpStreamAdapter somewhere down the stack, that
      // will call this method to provide the shortened path.

      KJ_IF_MAYBE(currentParent, parent) {
        parent = nullptr;

        // Ownership of this PathProber moves out of the parent's `state` and, below, into the
        // SubstreamCallbackImpl.
        auto self = kj::mv(currentParent->state.get<kj::Own<PathProber>>());
        currentParent->state = Ended();  // temporary, we'll set this properly below
        KJ_ASSERT(self.get() == this);

        // Open a substream on the target stream.
        auto req = target.getSubstreamRequest();
        req.setLimit(limit);
        auto paf = kj::newPromiseAndFulfiller<uint64_t>();
        req.setCallback(kj::heap<SubstreamCallbackImpl>(currentParent->factory,
            kj::mv(self), kj::mv(paf.fulfiller), limit));

        // Now we hook up the incoming stream adapter to point directly to this substream, yay.
        currentParent->state = req.send().getSubstream();

        // Let the original CapnpToKjStreamAdapter know that it's safe to handle incoming requests.
        readyFulfiller->fulfill();

        // It's now up to the SubstreamCallbackImpl to signal when the pump is done.
        return kj::mv(paf.promise);
      } else {
        // We already completed a path-shortening. Probably SubstreamCallbackImpl::ended() was
        // eventually called, meaning the substream was ended without redirecting back to us. So,
        // we're at EOF.
        return uint64_t(0);
      }
    }

    kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
      // If this is called, it means the tryPumpFrom() in probeForShorterPath() eventually invoked
      // code that tries to read manually from the source. We don't know what this code is doing
      // exactly, but we do know for sure that the endpoint is not a KjToCapnpStreamAdapter, so
      // we can't optimize. Instead, we pretend that we immediately hit EOF, ending the pump. This
      // works because pumps do not propagate EOF -- the destination can still receive further
      // writes and pumps. Basically our probing pump becomes a no-op, and then we revert to having
      // each write() RPC directly call write() on the inner stream.
      return size_t(0);
    }

    kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
      // Call the stream's `tryPumpFrom()` as a way to discover where the data will eventually go,
      // in hopes that we find we can shorten the path.
      KJ_IF_MAYBE(promise, output.tryPumpFrom(*this, amount)) {
        // tryPumpFrom() returned non-null. Either it called `tryRead()` or `pumpTo()` (see
        // below), or it plans to do so in the future.
        return kj::mv(*promise);
      } else {
        // There is no shorter path. As with tryRead(), we pretend we get immediate EOF.
        return uint64_t(0);
      }
    }

  private:
    kj::Maybe<CapnpToKjStreamAdapter&> parent;
    // The adapter whose `state` currently holds this prober; null while a shortened path is
    // active.
    kj::Own<kj::AsyncOutputStream> inner;
    // The wrapped KJ stream.
    kj::ForkedPromise<void> readyPromise;
    kj::Own<kj::PromiseFulfiller<void>> readyFulfiller;
    kj::Promise<void> task;
    // The probe started by startProbing().

    friend class SubstreamCallbackImpl;

    kj::Promise<void> probeForShorterPath() {
      // Pump (nothing) into `inner` to see where the data would go. If the pump finishes while
      // our parent is still waiting on us, give up on shortening and hand `inner` to the parent.
      return kj::evalNow([&]() -> kj::Promise<uint64_t> {
        return pumpTo(*inner, kj::maxValue);
      }).then([this](uint64_t actual) {
        KJ_IF_MAYBE(currentParent, parent) {
          KJ_IF_MAYBE(prober, currentParent->state.tryGet<kj::Own<PathProber>>()) {
            // Either we didn't find any shorter path at all during probing and faked an EOF
            // to get out of the probe (see comments in tryRead()), or we DID find a shorter path,
            // completed a pumpTo() using a substream, and that substream redirected back to us,
            // and THEN we couldn't find any further shorter paths for subsequent pumps.

            // HACK: If we overwrite the Probing state now, we'll delete ourselves and delete
            //   this task promise, which is an error... let the event loop do it later by
            //   detaching.
            task.attach(kj::mv(*prober)).detach([](kj::Exception&&){});
            parent = nullptr;

            // OK, now we can change the parent state and signal it to proceed.
            currentParent->state = kj::mv(inner);
            readyFulfiller->fulfill();
          }
        }
      }).eagerlyEvaluate([this](kj::Exception&& exception) mutable {
        // Something threw, so propagate the exception to break the parent.
        readyFulfiller->reject(kj::mv(exception));
      });
    }
  };

protected:
  // ---------------------------------------------------------------------------
  // implements ByteStream::Server RPC interface

  kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    return shortenPathImpl();
  }
  kj::Promise<Capability::Client> shortenPathImpl() {
    // Called by RPC implementation to find out if a shorter path presents itself.
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Decide once probing has settled the state.
        return prober->whenReady().then([this]() {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return shortenPathImpl();
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // No shortening possible. Pretend we never resolve so that calls continue to be routed
        // to us forever.
        return kj::NEVER_DONE;
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        return Capability::Client(capnpStream);
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::NEVER_DONE;
      }
      KJ_CASE_ONEOF(e, Ended) {
        // No shortening possible. Pretend we never resolve so that calls continue to be routed
        // to us forever.
        return kj::NEVER_DONE;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> write(WriteContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Wait for probing to settle, then retry against the new state.
        return prober->whenReady().then([this, context]() mutable {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return write(context);
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        auto data = context.getParams().getBytes();
        return kjStream->write(data.begin(), data.size());
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        auto params = context.getParams();
        auto req = capnpStream.writeRequest(params.totalSize());
        req.setBytes(params.getBytes());
        return req.send();
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()") { break; }
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> end(EndContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Wait for probing to settle, then retry against the new state.
        return prober->whenReady().then([this, context]() mutable {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return end(context);
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // TODO(someday): When KJ adds a proper .end() call, use it here. For now, we must
        //   drop the stream to close it.
        state = Ended();
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        auto params = context.getParams();
        auto req = capnpStream.endRequest(params.totalSize());
        return context.tailCall(kj::mv(req));
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()") { break; }
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> getSubstream(GetSubstreamContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Wait for probing to settle, then retry against the new state.
        return prober->whenReady().then([this, context]() mutable {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return getSubstream(context);
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // Lend the KJ stream to a new SubstreamImpl; it gives it back via returnStream().
        auto params = context.getParams();
        auto callback = params.getCallback();
        uint64_t limit = params.getLimit();
        context.releaseParams();

        auto results = context.initResults(MessageSize {2, 1});
        results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>(
            factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit))));
        state = Borrowed { kj::mv(kjStream) };
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        auto params = context.getParams();
        auto req = capnpStream.getSubstreamRequest(params.totalSize());
        req.setCallback(params.getCallback());
        req.setLimit(params.getLimit());
        return context.tailCall(kj::mv(req));
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()") { break; }
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

private:
  ByteStreamFactory& factory;

  struct Borrowed { kj::Own<kj::AsyncOutputStream> stream; };
  struct Ended {};

  kj::OneOf<kj::Own<PathProber>, kj::Own<kj::AsyncOutputStream>,
            capnp::ByteStream::Client, Borrowed, Ended> state;

  class SubstreamCallbackImpl final: public capnp::ByteStream::SubstreamCallback::Server {
    // Callback handed to the substream opened by PathProber::pumpToShorterPath(). It owns the
    // PathProber while the shortened path is active, and completes the original pump promise
    // when the substream either ends (ended()) or fills up (reachedLimit()).
  public:
    SubstreamCallbackImpl(ByteStreamFactory& factory,
                          kj::Own<PathProber> pathProber,
                          kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller,
                          uint64_t originalPumpLimit)
        : factory(factory),
          pathProber(kj::mv(pathProber)),
          originalPumpfulfiller(kj::mv(originalPumpfulfiller)),
          originalPumpLimit(originalPumpLimit) {}

    ~SubstreamCallbackImpl() noexcept(false) {
      // If neither callback ever arrived, fail the pending pump rather than leave it hanging.
      if (!done) {
        originalPumpfulfiller->reject(KJ_EXCEPTION(DISCONNECTED,
            "stream disconnected because SubstreamCallbackImpl was never called back"));
      }
    }

    kj::Promise<void> ended(EndedContext context) override {
      KJ_REQUIRE(!done);
      uint64_t actual = context.getParams().getByteCount();
      KJ_REQUIRE(actual <= originalPumpLimit);

      done = true;

      // EOF before pump completed. Signal a short pump.
      originalPumpfulfiller->fulfill(context.getParams().getByteCount());

      // Give the original pump task a chance to finish up.
      // The prober is attached to its own task so it stays alive until that task completes.
      return pathProber->task.attach(kj::mv(pathProber));
    }

    kj::Promise<void> reachedLimit(ReachedLimitContext context) override {
      KJ_REQUIRE(!done);
      done = true;

      // Allow the shortened stream to redirect back to our original underlying stream.
      auto results = context.getResults(capnp::MessageSize { 4, 1 });
      results.setNext(factory.streamSet.add(
          kj::heap<CapnpToKjStreamAdapter>(factory, kj::mv(pathProber))));

      // The full pump completed. Note that it's important that we fulfill this after the
      // PathProber has been attached to the new CapnpToKjStreamAdapter, which will have happened
      // in CapnpToKjStreamAdapter's constructor, which calls pathProber->setNewParent().
      originalPumpfulfiller->fulfill(kj::cp(originalPumpLimit));

      return kj::READY_NOW;
    }

  private:
    ByteStreamFactory& factory;
    kj::Own<PathProber> pathProber;
    kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller;
    uint64_t originalPumpLimit;
    bool done = false;
    // Set once ended() or reachedLimit() has been accepted.
  };
};
690
691 // =======================================================================================
692
693 class ByteStreamFactory::KjToCapnpStreamAdapter final: public kj::AsyncOutputStream {
694 public:
  // Wraps an outgoing ByteStream capability as a KJ output stream.
  //
  // NOTE: `findShorterPathTask` is initialized from `inner`, so this relies on `inner` being
  // declared (and therefore initialized) before `findShorterPathTask` in this class.
  KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam)
      : factory(factory),
        inner(kj::mv(innerParam)),
        findShorterPathTask(findShorterPath(inner).fork()) {}
699
~KjToCapnpStreamAdapter()700 ~KjToCapnpStreamAdapter() noexcept(false) {
701 // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We
702 // use a detached promise for now, which is probably OK since capabilities are refcounted and
703 // asynchronously destroyed anyway.
704 // TODO(cleanup): Fix this when KJ streads add an explicit end() method.
705 KJ_IF_MAYBE(o, optimized) {
706 o->directEnd();
707 } else {
708 inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
709 }
710 }
711
write(const void * buffer,size_t size)712 kj::Promise<void> write(const void* buffer, size_t size) override {
713 KJ_SWITCH_ONEOF(getShortestPath()) {
714 KJ_CASE_ONEOF(promise, kj::Promise<void>) {
715 return promise.then([this,buffer,size]() {
716 return write(buffer, size);
717 });
718 }
719 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
720 auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE);
721 if (size <= limit) {
722 auto promise = kjStream.stream.write(buffer, size);
723 return promise.then([kjStream,size]() mutable {
724 kjStream.lender.returnStream(size);
725 });
726 } else {
727 auto promise = kjStream.stream.write(buffer, limit);
728 return promise.then([this,kjStream,buffer,size,limit]() mutable {
729 kjStream.lender.returnStream(limit);
730 return write(reinterpret_cast<const byte*>(buffer) + limit,
731 size - limit);
732 });
733 }
734 }
735 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
736 if (size <= MAX_BYTES_PER_WRITE) {
737 auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 });
738 req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size));
739 return req.send();
740 } else {
741 auto req = capnpStream->writeRequest(
742 MessageSize { 8 + MAX_BYTES_PER_WRITE / sizeof(word), 0 });
743 req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), MAX_BYTES_PER_WRITE));
744 return req.send().then([this,buffer,size]() mutable {
745 return write(reinterpret_cast<const byte*>(buffer) + MAX_BYTES_PER_WRITE,
746 size - MAX_BYTES_PER_WRITE);
747 });
748 }
749 }
750 }
751 KJ_UNREACHABLE;
752 }
753
write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces)754 kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
755 KJ_SWITCH_ONEOF(getShortestPath()) {
756 KJ_CASE_ONEOF(promise, kj::Promise<void>) {
757 return promise.then([this,pieces]() {
758 return write(pieces);
759 });
760 }
761 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
762 size_t size = 0;
763 for (auto& piece: pieces) { size += piece.size(); }
764 auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE);
765 if (size <= limit) {
766 auto promise = kjStream.stream.write(pieces);
767 return promise.then([kjStream,size]() mutable {
768 kjStream.lender.returnStream(size);
769 });
770 } else {
771 // ughhhhhhhhhh, we need to split the pieces.
772 return splitAndWrite(pieces, kjStream.limit,
773 [kjStream,limit](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) mutable {
774 return kjStream.stream.write(pieces).then([kjStream,limit]() mutable {
775 kjStream.lender.returnStream(limit);
776 });
777 });
778 }
779 }
780 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
781 auto writePieces = [capnpStream](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
782 size_t size = 0;
783 for (auto& piece: pieces) size += piece.size();
784 auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 });
785 auto out = req.initBytes(size);
786 byte* ptr = out.begin();
787 for (auto& piece: pieces) {
788 memcpy(ptr, piece.begin(), piece.size());
789 ptr += piece.size();
790 }
791 KJ_ASSERT(ptr == out.end());
792 return req.send();
793 };
794
795 size_t size = 0;
796 for (auto& piece: pieces) size += piece.size();
797 if (size <= MAX_BYTES_PER_WRITE) {
798 return writePieces(pieces);
799 } else {
800 // ughhhhhhhhhh, we need to split the pieces.
801 return splitAndWrite(pieces, MAX_BYTES_PER_WRITE, writePieces);
802 }
803 }
804 }
805 KJ_UNREACHABLE;
806 }
807
tryPumpFrom(kj::AsyncInputStream & input,uint64_t amount=kj::maxValue)808 kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
809 kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
810 KJ_IF_MAYBE(rpc, kj::dynamicDowncastIfAvailable<CapnpToKjStreamAdapter::PathProber>(input)) {
811 // Oh interesting, it turns we're hosting an incoming ByteStream which is pumping to this
812 // outgoing ByteStream. We can let the Cap'n Proto RPC layer know that it can shorten the
813 // path from one to the other.
814 return rpc->pumpToShorterPath(inner, amount);
815 } else {
816 return pumpLoop(input, 0, amount);
817 }
818 }
819
  kj::Promise<void> whenWriteDisconnected() override {
    // Backed by findShorterPathTask: it completes when the capability turns out to be a promise
    // that rejects, and never completes if the capability resolves to a non-local server.
    return findShorterPathTask.addBranch();
  }
823
private:
  ByteStreamFactory& factory;
  // The factory that created this adapter. findShorterPath() queries its `streamSet` to check
  // whether `inner` resolves to a ByteStream server hosted in this process.

  capnp::ByteStream::Client inner;
  // The capnp stream that writes are addressed to. getShortestPath() falls back to this when
  // no local shortcut has been found, and tryPumpFrom() hands it to the RPC layer.

  kj::Maybe<StreamServerBase&> optimized;
  // Set by findShorterPath() once `inner` is found to point at a StreamServerBase in this
  // process. When set, getShortestPath() delegates to that server instead of using `inner`.

  kj::ForkedPromise<void> findShorterPathTask;
  // This serves two purposes:
  // 1. Waits for the capability to resolve (if it is a promise), and then shortens the path if
  //    possible.
  // 2. Implements whenWriteDisconnected(), whose callers each receive a branch of this task.
834
findShorterPath(capnp::ByteStream::Client & capnpClient)835 kj::Promise<void> findShorterPath(capnp::ByteStream::Client& capnpClient) {
836 // If the capnp stream turns out to resolve back to this process, shorten the path.
837 // Also, implement whenWriteDisconnected() based on this.
838 return factory.streamSet.getLocalServer(capnpClient)
839 .then([this](kj::Maybe<capnp::ByteStream::Server&> server) -> kj::Promise<void> {
840 KJ_IF_MAYBE(s, server) {
841 // Yay, we discovered that the ByteStream actually points back to a local KJ stream.
842 // We can use this to shorten the path by skipping the RPC machinery.
843 return findShorterPath(kj::downcast<StreamServerBase>(*s));
844 } else {
845 // The capability is fully-resolved. This suggests that the remote implementation is
846 // NOT a CapnpToKjStreamAdapter at all, because CapnpToKjStreamAdapter is designed to
847 // always look like a promise. It's some other implementation that doesn't present
848 // itself as a promise. We have no way to detect when it is disconnected.
849 return kj::NEVER_DONE;
850 }
851 }, [](kj::Exception&& e) -> kj::Promise<void> {
852 // getLocalServer() thrown when the capability is a promise cap that rejects. We can
853 // use this to implement whenWriteDisconnected().
854 //
855 // (Note that because this exception handler is passed to the .then(), it does NOT catch
856 // eoxceptions thrown by the success handler immediately above it. This handler will ONLY
857 // catch exceptions from getLocalServer() itself.)
858 return kj::READY_NOW;
859 });
860 }
861
findShorterPath(StreamServerBase & capnpServer)862 kj::Promise<void> findShorterPath(StreamServerBase& capnpServer) {
863 // We found a shorter path back to this process. Record it.
864 optimized = capnpServer;
865
866 KJ_SWITCH_ONEOF(capnpServer.getShortestPath()) {
867 KJ_CASE_ONEOF(promise, kj::Promise<void>) {
868 return promise.then([this,&capnpServer]() {
869 return findShorterPath(capnpServer);
870 });
871 }
872 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
873 // The ByteStream::Server wraps a regular KJ stream that does not wrap another capnp
874 // stream.
875 if (kjStream.limit < (uint64_t)kj::maxValue / 2) {
876 // But it isn't wrapping that stream forever. Eventually it plans to redirect back to
877 // some other stream. So, let's wait for that, and possibly shorten again.
878 kjStream.lender.returnStream(0);
879 return KJ_ASSERT_NONNULL(capnpServer.shortenPath())
880 .then([this, &capnpServer](auto&&) {
881 return findShorterPath(capnpServer);
882 });
883 } else {
884 // This KJ stream is (effectively) the permanent endpoint. We can't get any shorter
885 // from here. All we want to do now is watch for disconnect.
886 auto promise = kjStream.stream.whenWriteDisconnected();
887 kjStream.lender.returnStream(0);
888 return promise;
889 }
890 }
891 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
892 return findShorterPath(*capnpStream);
893 }
894 }
895 KJ_UNREACHABLE;
896 }
897
getShortestPath()898 StreamServerBase::ShortestPath getShortestPath() {
899 KJ_IF_MAYBE(o, optimized) {
900 return o->getShortestPath();
901 } else {
902 return &inner;
903 }
904 }
905
pumpLoop(kj::AsyncInputStream & input,uint64_t completed,uint64_t remaining)906 kj::Promise<uint64_t> pumpLoop(kj::AsyncInputStream& input,
907 uint64_t completed, uint64_t remaining) {
908 if (remaining == 0) return completed;
909
910 KJ_SWITCH_ONEOF(getShortestPath()) {
911 KJ_CASE_ONEOF(promise, kj::Promise<void>) {
912 return promise.then([this,&input,completed,remaining]() {
913 return pumpLoop(input,completed,remaining);
914 });
915 }
916 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
917 // Oh hell yes, this capability actually points back to a stream in our own thread. We can
918 // stop sending RPCs and just pump directly.
919
920 if (remaining <= kjStream.limit) {
921 return input.pumpTo(kjStream.stream, remaining)
922 .then([kjStream,completed](uint64_t actual) {
923 kjStream.lender.returnStream(actual);
924 return actual + completed;
925 });
926 } else {
927 auto promise = input.pumpTo(kjStream.stream, kjStream.limit);
928 return promise.then([this,&input,completed,remaining,kjStream]
929 (uint64_t actual) mutable -> kj::Promise<uint64_t> {
930 kjStream.lender.returnStream(actual);
931 if (actual < kjStream.limit) {
932 // EOF reached.
933 return completed + actual;
934 } else {
935 return pumpLoop(input, completed + actual, remaining - actual);
936 }
937 });
938 }
939 }
940 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
941 // Pumping from some other kind of steram. Optimize the pump by reading from the input
942 // directly into outgoing RPC messages.
943 size_t size = kj::min(remaining, 8192);
944 auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word) });
945
946 auto orphanage = Orphanage::getForMessageContaining(
947 capnp::ByteStream::WriteParams::Builder(req));
948
949 auto buffer = orphanage.newOrphan<Data>(size);
950
951 struct WriteRequestAndBuffer {
952 // The order of construction/destruction of lambda captures is unspecified, but we care
953 // about ordering between these two things that we want to capture, so... we need a
954 // struct.
955 StreamingRequest<capnp::ByteStream::WriteParams> request;
956 Orphan<Data> buffer; // points into `request`...
957 };
958
959 WriteRequestAndBuffer wrab = { kj::mv(req), kj::mv(buffer) };
960
961 return input.tryRead(wrab.buffer.get().begin(), 1, size)
962 .then([this, &input, completed, remaining, size, wrab = kj::mv(wrab)]
963 (size_t actual) mutable -> kj::Promise<uint64_t> {
964 if (actual == 0) {
965 return completed;
966 } if (actual < size) {
967 wrab.buffer.truncate(actual);
968 }
969
970 wrab.request.adoptBytes(kj::mv(wrab.buffer));
971 return wrab.request.send()
972 .then([this, &input, completed, remaining, actual]() {
973 return pumpLoop(input, completed + actual, remaining - actual);
974 });
975 });
976 }
977 }
978 KJ_UNREACHABLE;
979 }
980
981 template <typename WritePieces>
splitAndWrite(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces,size_t limit,WritePieces && writeFirstPieces)982 kj::Promise<void> splitAndWrite(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces,
983 size_t limit, WritePieces&& writeFirstPieces) {
984 size_t splitByte = limit;
985 size_t splitPiece = 0;
986 while (pieces[splitPiece].size() <= splitByte) {
987 splitByte -= pieces[splitPiece].size();
988 ++splitPiece;
989 }
990
991 if (splitByte == 0) {
992 // Oh thank god, the split is between two pieces.
993 auto rest = pieces.slice(splitPiece, pieces.size());
994 return writeFirstPieces(pieces.slice(0, splitPiece))
995 .then([this,rest]() mutable {
996 return write(rest);
997 });
998 } else {
999 // FUUUUUUUU---- we need to split one of the pieces in two.
1000 auto left = kj::heapArray<kj::ArrayPtr<const byte>>(splitPiece + 1);
1001 auto right = kj::heapArray<kj::ArrayPtr<const byte>>(pieces.size() - splitPiece);
1002 for (auto i: kj::zeroTo(splitPiece)) {
1003 left[i] = pieces[i];
1004 }
1005 for (auto i: kj::zeroTo(right.size())) {
1006 right[i] = pieces[splitPiece + i];
1007 }
1008 left.back() = pieces[splitPiece].slice(0, splitByte);
1009 right.front() = pieces[splitPiece].slice(splitByte, pieces[splitPiece].size());
1010
1011 return writeFirstPieces(left).attach(kj::mv(left))
1012 .then([this,right=kj::mv(right)]() mutable {
1013 return write(right).attach(kj::mv(right));
1014 });
1015 }
1016 }
1017 };
1018
1019 // =======================================================================================
1020
kjToCapnp(kj::Own<kj::AsyncOutputStream> kjStream)1021 capnp::ByteStream::Client ByteStreamFactory::kjToCapnp(kj::Own<kj::AsyncOutputStream> kjStream) {
1022 return streamSet.add(kj::heap<CapnpToKjStreamAdapter>(*this, kj::mv(kjStream)));
1023 }
1024
capnpToKj(capnp::ByteStream::Client capnpStream)1025 kj::Own<kj::AsyncOutputStream> ByteStreamFactory::capnpToKj(capnp::ByteStream::Client capnpStream) {
1026 return kj::heap<KjToCapnpStreamAdapter>(*this, kj::mv(capnpStream));
1027 }
1028
1029 } // namespace capnp
1030