1 // Copyright (c) 2019 Cloudflare, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 
22 #include "byte-stream.h"
23 #include <kj/one-of.h>
24 #include <kj/debug.h>
25 
26 namespace capnp {
27 
28 const uint MAX_BYTES_PER_WRITE = 1 << 16;
29 
30 class ByteStreamFactory::StreamServerBase: public capnp::ByteStream::Server {
31 public:
32   virtual void returnStream(uint64_t written) = 0;
33   // Called after the StreamServerBase's internal kj::AsyncOutputStream has been borrowed, to
34   // indicate that the borrower is done.
35   //
36   // A stream becomes borrowed either when getShortestPath() returns a BorrowedStream, or when
37   // a SubstreamImpl is constructed wrapping an existing stream.
38 
39   struct BorrowedStream {
40     // Represents permission to use the StreamServerBase's inner AsyncOutputStream directly, up
41     // to some limit of bytes written.
42 
43     StreamServerBase& lender;
44     kj::AsyncOutputStream& stream;
45     uint64_t limit;
46   };
47 
48   typedef kj::OneOf<kj::Promise<void>, capnp::ByteStream::Client*, BorrowedStream> ShortestPath;
49 
50   virtual ShortestPath getShortestPath() = 0;
51   // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client
52   // actually points back to a StreamServerBase in the same process created by the same
53   // ByteStreamFactory. Returns the best shortened path to use, or a promise that resolves when the
54   // shortest path is known.
55 
56   virtual void directEnd() = 0;
57   // Called by KjToCapnpStreamAdapter's destructor when it has determined that its inner
58   // ByteStream::Client actually points back to a StreamServerBase in the same process created by
59   // the same ByteStreamFactory. Since destruction of a KJ stream signals EOF, we need to propagate
60   // that by destroying our underlying stream.
61   // TODO(cleanup): When KJ streams evolve an end() method, this can go away.
62 };
63 
class ByteStreamFactory::SubstreamImpl final: public StreamServerBase {
  // A ByteStream server that writes directly into a kj::AsyncOutputStream borrowed from `parent`,
  // for at most `limit` bytes.
  //
  // While this object is Streaming, its parent sits in the parent's own Borrowed state. Once
  // `limit` bytes have been written, limitReached() does three things:
  // - calls `callback.reachedLimit()`,
  // - returns the stream to the parent,
  // - moves to Redirected, where every further call is forwarded to the ByteStream obtained
  //   from `reachedLimit()`'s result.
  // If the stream is ended before the limit, `callback.ended()` is told how many bytes were
  // written, the stream is returned to the parent, and the state becomes Ended.
  //
  // State transitions:
  //   Streaming -> Borrowed    getShortestPath() lends the stream, or getSubstream() nests.
  //   Borrowed  -> Streaming   returnStream().
  //   Streaming -> Redirected  limitReached().
  //   Streaming -> Ended       end() / directEnd().
public:
  // `parent` is the server whose stream is borrowed and `stream` is that stream. `ownParent` is a
  // capability to the same parent, stored in the Streaming state. NOTE(review): presumably it is
  // held to keep the parent alive while borrowed; confirm.
  // `callback` is notified via ended() / reachedLimit(). `limit` is the byte budget.
  // `paf` is only a default-argument device for creating the fulfiller/promise pair that
  // shortenPath() waits on. The callers in this file do not pass it.
  SubstreamImpl(ByteStreamFactory& factory,
                StreamServerBase& parent,
                capnp::ByteStream::Client ownParent,
                kj::AsyncOutputStream& stream,
                capnp::ByteStream::SubstreamCallback::Client callback,
                uint64_t limit,
                kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>())
      : factory(factory),
        state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback)}),
        limit(limit),
        resolveFulfiller(kj::mv(paf.fulfiller)),
        resolvePromise(paf.promise.fork()) {}

  // ---------------------------------------------------------------------------
  // implements StreamServerBase

  // Called when whoever borrowed our stream is done. The borrower is either a
  // KjToCapnpStreamAdapter writing directly, or a nested SubstreamImpl. `written` bytes count
  // against our limit. If they exhaust it exactly, we redirect right away.
  void returnStream(uint64_t written) override {
    completed += written;
    KJ_ASSERT(completed <= limit);
    auto borrowed = kj::mv(state.get<Borrowed>());
    state = kj::mv(borrowed.originalState);

    if (completed == limit) {
      limitReached();
    }
  }

  // Once Redirected, point the caller at the replacement capability. While Streaming, lend the
  // stream out, capped at the bytes still remaining in our budget.
  ShortestPath getShortestPath() override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        return &redirected.replacement;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while substream is active");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Grab the stream reference before `streaming` (which lives inside `state`) is moved
        // out and `state` is overwritten.
        auto& stream = streaming.stream;
        auto oldState = kj::mv(streaming);
        state = Borrowed { kj::mv(oldState) };
        return BorrowedStream { *this, stream, limit - completed };
      }
    }
    KJ_UNREACHABLE;
  }

  // Propagates EOF from a destroyed KjToCapnpStreamAdapter. Errors from the fire-and-forget RPCs
  // sent here are deliberately discarded.
  void directEnd() override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        // Ugh I guess we need to send a real end() request here.
        redirected.replacement.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
      }
      KJ_CASE_ONEOF(e, Ended) {
        // Already ended; nothing to do.
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        // Stream is currently lent out; this case intentionally does nothing.
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Tell the callback how many bytes were written, then hand the stream back to the parent.
        auto req = streaming.callback.endedRequest(MessageSize {4, 0});
        req.setByteCount(completed);
        req.send().detach([](kj::Exception&&){});
        streaming.parent.returnStream(completed);
        state = Ended();
      }
    }
  }

  // ---------------------------------------------------------------------------
  // implements ByteStream::Server RPC interface

  // Resolves to the replacement stream once limitReached() fulfills `resolveFulfiller`. That only
  // happens when entering Redirected, so `state.get<Redirected>()` is valid in the continuation.
  kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    return resolvePromise.addBranch()
        .then([this]() -> Capability::Client {
      return state.get<Redirected>().replacement;
    });
  }

  kj::Promise<void> write(WriteContext context) override {
    auto params = context.getParams();
    auto data = params.getBytes();

    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        auto req = redirected.replacement.writeRequest(params.totalSize());
        req.setBytes(data);
        return req.send();
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        if (completed + data.size() < limit) {
          // Strictly below the limit: write straight through.
          completed += data.size();
          return streaming.stream.write(data.begin(), data.size());
        } else {
          // This write passes the limit.
          // The write reaches or passes the limit. Write only the bytes that fit. Once that
          // completes, switch to the next stream and forward any leftover bytes there.
          // `completed` is updated only after the partial write finishes.
          uint64_t remainder = limit - completed;
          auto leftover = data.slice(remainder, data.size());
          return streaming.stream.write(data.begin(), remainder)
              .then([this, leftover]() -> kj::Promise<void> {
            completed = limit;
            limitReached();

            if (leftover.size() > 0) {
              // Need to forward the leftover bytes to the next stream.
              auto req = state.get<Redirected>().replacement.writeRequest(
                  MessageSize { 4 + leftover.size() / sizeof(capnp::word), 0 });
              req.setBytes(leftover);
              return req.send();
            } else {
              return kj::READY_NOW;
            }
          });
        }
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> end(EndContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        return context.tailCall(redirected.replacement.endRequest(MessageSize {2,0}));
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Ended early: report the byte count to the callback and return the stream to the
        // parent. The RPC completes when the callback's ended() call does.
        auto req = streaming.callback.endedRequest(MessageSize {4, 0});
        req.setByteCount(completed);
        auto result = req.send().ignoreResult();
        streaming.parent.returnStream(completed);
        state = Ended();
        return result;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> getSubstream(GetSubstreamContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(redirected, Redirected) {
        auto params = context.getParams();
        auto req = redirected.replacement.getSubstreamRequest(params.totalSize());
        req.setCallback(params.getCallback());
        req.setLimit(params.getLimit());
        return context.tailCall(kj::mv(req));
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()");
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
      }
      KJ_CASE_ONEOF(streaming, Streaming) {
        // Nest a new substream on top of our borrowed stream and enter Borrowed until it
        // returns the stream.
        // Note that the local `limit` here (the child's limit, from params) shadows the member.
        // NOTE(review): the child's limit is not bounded by our remaining `limit - completed`.
        //   If the child writes more than that, returnStream()'s KJ_ASSERT(completed <= limit)
        //   would fire. Confirm that callers never request more than the remaining budget.
        auto params = context.getParams();
        auto callback = params.getCallback();
        auto limit = params.getLimit();
        context.releaseParams();
        auto results = context.getResults(MessageSize { 2, 1 });
        results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>(
            factory, *this, thisCap(), streaming.stream, kj::mv(callback), kj::mv(limit))));
        state = Borrowed { kj::mv(streaming) };
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

private:
  ByteStreamFactory& factory;

  struct Streaming {
    // Actively writing to `stream`, which belongs to `parent`.
    StreamServerBase& parent;
    capnp::ByteStream::Client ownParent;
    kj::AsyncOutputStream& stream;
    capnp::ByteStream::SubstreamCallback::Client callback;
  };
  struct Borrowed {
    // Our stream is lent out. `originalState` is restored by returnStream().
    Streaming originalState;
  };
  struct Redirected {
    // Limit reached; every call is forwarded to `replacement`.
    capnp::ByteStream::Client replacement;
  };
  struct Ended {};

  kj::OneOf<Streaming, Borrowed, Redirected, Ended> state;

  uint64_t limit;          // Total bytes this substream may carry.
  uint64_t completed = 0;  // Bytes accounted as written so far.

  // Fulfilled by limitReached() to release shortenPath() waiters.
  kj::Own<kj::PromiseFulfiller<void>> resolveFulfiller;
  kj::ForkedPromise<void> resolvePromise;

  // Switches from Streaming to Redirected. In order, it:
  // 1. asks the callback for the next stream, via a pipelined `getNext()` so no round trip is
  //    needed,
  // 2. gives the full `limit` back to the parent,
  // 3. installs the next stream as the replacement,
  // 4. wakes up shortenPath().
  void limitReached() {
    auto& streaming = state.get<Streaming>();
    auto next = streaming.callback.reachedLimitRequest(capnp::MessageSize {2,0})
        .send().getNext();

    // Set the next stream as our replacement.
    streaming.parent.returnStream(limit);
    state = Redirected { kj::mv(next) };
    resolveFulfiller->fulfill();
  }
};
280 
281 // =======================================================================================
282 
class ByteStreamFactory::CapnpToKjStreamAdapter final: public StreamServerBase {
  // Implements Cap'n Proto ByteStream as a wrapper around a KJ stream.
  //
  // `state` (declared below) is one of the following:
  // - kj::Own<PathProber>: probing for a shorter path. This is the initial state when the adapter
  //   is constructed from a raw kj::AsyncOutputStream. RPC calls wait on
  //   PathProber::whenReady() and then retry.
  // - kj::Own<kj::AsyncOutputStream>: no shorter path was found; RPC calls write to the KJ
  //   stream.
  // - capnp::ByteStream::Client: a shorter path was found (a substream opened by
  //   PathProber::pumpToShorterPath()); RPC calls are forwarded to it.
  // - Borrowed: the KJ stream is lent to a direct writer (getShortestPath()) or to a
  //   SubstreamImpl (getSubstream()); returnStream() puts it back.
  // - Ended: the KJ stream has been dropped by end() or directEnd(). pumpToShorterPath() also
  //   uses this state transiently.

  class SubstreamCallbackImpl;

public:
  class PathProber;

  // Wraps a raw KJ stream and immediately starts probing it for a shorter path.
  CapnpToKjStreamAdapter(ByteStreamFactory& factory,
                         kj::Own<kj::AsyncOutputStream> inner)
      : factory(factory),
        state(kj::heap<PathProber>(*this, kj::mv(inner))) {
    state.get<kj::Own<PathProber>>()->startProbing();
  }

  // Adopts an existing PathProber, which has already done a round of path-shortening. Used by
  // SubstreamCallbackImpl::reachedLimit() to hand the stream back after a substream is used up.
  CapnpToKjStreamAdapter(ByteStreamFactory& factory,
                         kj::Own<PathProber> pathProber)
      : factory(factory),
        state(kj::mv(pathProber)) {
    state.get<kj::Own<PathProber>>()->setNewParent(*this);
  }

  // ---------------------------------------------------------------------------
  // implements StreamServerBase

  // This adapter places no limit on borrowers (getShortestPath() lends with kj::maxValue), so
  // `written` is not needed; just take the KJ stream back.
  void returnStream(uint64_t written) override {
    auto stream = kj::mv(state.get<Borrowed>().stream);
    state = kj::mv(stream);
  }

  ShortestPath getShortestPath() override {
    // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client
    // actually points back to a CapnpToKjStreamAdapter in the same process. Returns the best
    // shortened path to use, or a promise that resolves when the shortest path is known.

    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Still probing; the caller should retry once probing settles.
        return prober->whenReady();
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // Lend the KJ stream out with no byte limit.
        auto& streamRef = *kjStream;
        state = Borrowed { kj::mv(kjStream) };
        return StreamServerBase::BorrowedStream { *this, streamRef, kj::maxValue };
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        return &capnpStream;
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::Promise<void>(kj::READY_NOW);
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already ended") { break; }
        return kj::Promise<void>(kj::READY_NOW);
      }
    }
    KJ_UNREACHABLE;
  }

  // Propagates EOF from a destroyed KjToCapnpStreamAdapter. Dropping the PathProber or the KJ
  // stream (by switching to Ended) is what closes the underlying KJ stream.
  void directEnd() override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        state = Ended();
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        state = Ended();
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        // Ugh I guess we need to send a real end() request here.
        capnpStream.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        // Fine, ignore.
      }
      KJ_CASE_ONEOF(e, Ended) {
        // Fine, ignore.
      }
    }
  }

  // ---------------------------------------------------------------------------
  // PathProber

  // A kj::AsyncInputStream that is pumped into the adapter's inner KJ stream. It is never read
  // for real data: pumping into `inner` lets the output side reveal, through tryPumpFrom(),
  // whether it ultimately leads to a KjToCapnpStreamAdapter. If it does, that adapter calls
  // pumpToShorterPath() and the parent CapnpToKjStreamAdapter is redirected to a substream of the
  // target.
  class PathProber final: public kj::AsyncInputStream {
  public:
    // `paf` is only a default-argument device for creating the readiness promise/fulfiller pair.
    PathProber(CapnpToKjStreamAdapter& parent, kj::Own<kj::AsyncOutputStream> inner,
               kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>())
        : parent(parent), inner(kj::mv(inner)),
          readyPromise(paf.promise.fork()),
          readyFulfiller(kj::mv(paf.fulfiller)),
          task(nullptr) {}

    // Kicks off the probing pump. This is kept separate from the constructor so that the
    // prober is already owned by the parent's `state` when probing begins.
    void startProbing() {
      task = probeForShorterPath();
    }

    // Attaches this prober to a new CapnpToKjStreamAdapter after a previous parent gave it up
    // (pumpToShorterPath() clears `parent`). It also arms a fresh readiness promise for the new
    // parent's callers.
    void setNewParent(CapnpToKjStreamAdapter& newParent) {
      KJ_ASSERT(parent == nullptr);
      parent = newParent;
      auto paf = kj::newPromiseAndFulfiller<void>();
      readyPromise = paf.promise.fork();
      readyFulfiller = kj::mv(paf.fulfiller);
    }

    // Resolves once the parent has left the probing state (or rejects if probing threw).
    kj::Promise<void> whenReady() {
      return readyPromise.addBranch();
    }

    kj::Promise<uint64_t> pumpToShorterPath(capnp::ByteStream::Client target, uint64_t limit) {
      // If our probe succeeds in finding a KjToCapnpStreamAdapter somewhere down the stack, that
      // will call this method to provide the shortened path.
      //
      // Returns the number of bytes the redirected substream ended up carrying. That count
      // comes from SubstreamCallbackImpl, via `paf` below.

      KJ_IF_MAYBE(currentParent, parent) {
        parent = nullptr;

        // Take ownership of ourselves out of the parent's state. From here on the
        // SubstreamCallbackImpl owns this prober.
        auto self = kj::mv(currentParent->state.get<kj::Own<PathProber>>());
        currentParent->state = Ended();  // temporary, we'll set this properly below
        KJ_ASSERT(self.get() == this);

        // Open a substream on the target stream.
        auto req = target.getSubstreamRequest();
        req.setLimit(limit);
        auto paf = kj::newPromiseAndFulfiller<uint64_t>();
        req.setCallback(kj::heap<SubstreamCallbackImpl>(currentParent->factory,
            kj::mv(self), kj::mv(paf.fulfiller), limit));

        // Now we hook up the incoming stream adapter to point directly to this substream, yay.
        currentParent->state = req.send().getSubstream();

        // Let the original CapnpToKjStreamAdapter know that it's safe to handle incoming requests.
        readyFulfiller->fulfill();

        // It's now up to the SubstreamCallbackImpl to signal when the pump is done.
        return kj::mv(paf.promise);
      } else {
        // We already completed a path-shortening. Probably SubstreamCallbackImpl::ended() was
        // eventually called, meaning the substream was ended without redirecting back to us. So,
        // we're at EOF.
        return uint64_t(0);
      }
    }

    kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
      // If this is called, it means the tryPumpFrom() in probeForShorterPath() eventually invoked
      // code that tries to read manually from the source. We don't know what this code is doing
      // exactly, but we do know for sure that the endpoint is not a KjToCapnpStreamAdapter, so
      // we can't optimize. Instead, we pretend that we immediately hit EOF, ending the pump. This
      // works because pumps do not propagate EOF -- the destination can still receive further
      // writes and pumps. Basically our probing pump becomes a no-op, and then we revert to having
      // each write() RPC directly call write() on the inner stream.
      return size_t(0);
    }

    kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
      // Call the stream's `tryPumpFrom()` as a way to discover where the data will eventually go,
      // in hopes that we find we can shorten the path.
      KJ_IF_MAYBE(promise, output.tryPumpFrom(*this, amount)) {
        // tryPumpFrom() returned non-null. Either it called `tryRead()` or `pumpTo()` (see
        // below), or it plans to do so in the future.
        return kj::mv(*promise);
      } else {
        // There is no shorter path. As with tryRead(), we pretend we get immediate EOF.
        return uint64_t(0);
      }
    }

  private:
    kj::Maybe<CapnpToKjStreamAdapter&> parent;  // null while a shortened path is in use
    kj::Own<kj::AsyncOutputStream> inner;       // the real KJ destination
    kj::ForkedPromise<void> readyPromise;       // see whenReady()
    kj::Own<kj::PromiseFulfiller<void>> readyFulfiller;
    kj::Promise<void> task;                     // the probing pump started by startProbing()

    friend class SubstreamCallbackImpl;

    // Pumps an unbounded amount from this prober into `inner`. When the pump finishes and we
    // still have a parent in the probing state, we give up on shortening and hand `inner` to the
    // parent. Exceptions are forwarded to whoever is waiting on whenReady().
    kj::Promise<void> probeForShorterPath() {
      return kj::evalNow([&]() -> kj::Promise<uint64_t> {
        return pumpTo(*inner, kj::maxValue);
      }).then([this](uint64_t actual) {
        KJ_IF_MAYBE(currentParent, parent) {
          KJ_IF_MAYBE(prober, currentParent->state.tryGet<kj::Own<PathProber>>()) {
            // Either we didn't find any shorter path at all during probing and faked an EOF
            // to get out of the probe (see comments in tryRead()), or we DID find a shorter path,
            // completed a pumpTo() using a substream, and that substream redirected back to us,
            // and THEN we couldn't find any further shorter paths for subsequent pumps.

            // HACK: If we overwrite the Probing state now, we'll delete ourselves and delete
            //   this task promise, which is an error... let the event loop do it later by
            //   detaching.
            task.attach(kj::mv(*prober)).detach([](kj::Exception&&){});
            parent = nullptr;

            // OK, now we can change the parent state and signal it to proceed.
            currentParent->state = kj::mv(inner);
            readyFulfiller->fulfill();
          }
        }
      }).eagerlyEvaluate([this](kj::Exception&& exception) mutable {
        // Something threw, so propagate the exception to break the parent.
        readyFulfiller->reject(kj::mv(exception));
      });
    }
  };

protected:
  // ---------------------------------------------------------------------------
  // implements ByteStream::Server RPC interface

  kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    return shortenPathImpl();
  }
  kj::Promise<Capability::Client> shortenPathImpl() {
    // Called by RPC implementation to find out if a shorter path presents itself.
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Wait until probing settles, then re-evaluate against the new state.
        return prober->whenReady().then([this]() {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return shortenPathImpl();
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // No shortening possible. Pretend we never resolve so that calls continue to be routed
        // to us forever.
        return kj::NEVER_DONE;
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        return Capability::Client(capnpStream);
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::NEVER_DONE;
      }
      KJ_CASE_ONEOF(e, Ended) {
        // No shortening possible. Pretend we never resolve so that calls continue to be routed
        // to us forever.
        return kj::NEVER_DONE;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> write(WriteContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Defer until probing settles, then retry against the new state.
        return prober->whenReady().then([this, context]() mutable {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return write(context);
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        auto data = context.getParams().getBytes();
        return kjStream->write(data.begin(), data.size());
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        auto params = context.getParams();
        auto req = capnpStream.writeRequest(params.totalSize());
        req.setBytes(params.getBytes());
        return req.send();
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()") { break; }
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> end(EndContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Defer until probing settles, then retry against the new state.
        return prober->whenReady().then([this, context]() mutable {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return end(context);
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // TODO(someday): When KJ adds a proper .end() call, use it here. For now, we must
        //   drop the stream to close it.
        state = Ended();
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        auto params = context.getParams();
        auto req = capnpStream.endRequest(params.totalSize());
        return context.tailCall(kj::mv(req));
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()") { break; }
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

  kj::Promise<void> getSubstream(GetSubstreamContext context) override {
    KJ_SWITCH_ONEOF(state) {
      KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
        // Defer until probing settles, then retry against the new state.
        return prober->whenReady().then([this, context]() mutable {
          KJ_ASSERT(!state.is<kj::Own<PathProber>>());
          return getSubstream(context);
        });
      }
      KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
        // Create a SubstreamImpl that writes directly to our KJ stream, and stay Borrowed until
        // it calls returnStream().
        auto params = context.getParams();
        auto callback = params.getCallback();
        uint64_t limit = params.getLimit();
        context.releaseParams();

        auto results = context.initResults(MessageSize {2, 1});
        results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>(
            factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit))));
        state = Borrowed { kj::mv(kjStream) };
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
        auto params = context.getParams();
        auto req = capnpStream.getSubstreamRequest(params.totalSize());
        req.setCallback(params.getCallback());
        req.setLimit(params.getLimit());
        return context.tailCall(kj::mv(req));
      }
      KJ_CASE_ONEOF(b, Borrowed) {
        KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
        return kj::READY_NOW;
      }
      KJ_CASE_ONEOF(e, Ended) {
        KJ_FAIL_REQUIRE("already called end()") { break; }
        return kj::READY_NOW;
      }
    }
    KJ_UNREACHABLE;
  }

private:
  ByteStreamFactory& factory;

  struct Borrowed { kj::Own<kj::AsyncOutputStream> stream; };
  struct Ended {};

  kj::OneOf<kj::Own<PathProber>, kj::Own<kj::AsyncOutputStream>,
            capnp::ByteStream::Client, Borrowed, Ended> state;

  // The callback passed with the getSubstream() request in PathProber::pumpToShorterPath(). It
  // owns the PathProber while the shortened path is in use. It settles the original probing
  // pump in one of two ways:
  // - ended(): a short pump, completed with the reported byte count;
  // - reachedLimit(): the full limit was carried; the stream is handed back to a new
  //   CapnpToKjStreamAdapter.
  class SubstreamCallbackImpl final: public capnp::ByteStream::SubstreamCallback::Server {
  public:
    SubstreamCallbackImpl(ByteStreamFactory& factory,
                          kj::Own<PathProber> pathProber,
                          kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller,
                          uint64_t originalPumpLimit)
        : factory(factory),
          pathProber(kj::mv(pathProber)),
          originalPumpfulfiller(kj::mv(originalPumpfulfiller)),
          originalPumpLimit(originalPumpLimit) {}

    // If neither callback method ever ran, fail the pending pump instead of leaving it hanging.
    ~SubstreamCallbackImpl() noexcept(false) {
      if (!done) {
        originalPumpfulfiller->reject(KJ_EXCEPTION(DISCONNECTED,
            "stream disconnected because SubstreamCallbackImpl was never called back"));
      }
    }

    kj::Promise<void> ended(EndedContext context) override {
      KJ_REQUIRE(!done);
      uint64_t actual = context.getParams().getByteCount();
      KJ_REQUIRE(actual <= originalPumpLimit);

      done = true;

      // EOF before pump completed. Signal a short pump.
      originalPumpfulfiller->fulfill(context.getParams().getByteCount());

      // Give the original pump task a chance to finish up.
      // The prober is attached to its own task so it stays alive until that task completes.
      return pathProber->task.attach(kj::mv(pathProber));
    }

    kj::Promise<void> reachedLimit(ReachedLimitContext context) override {
      KJ_REQUIRE(!done);
      done = true;

      // Allow the shortened stream to redirect back to our original underlying stream.
      auto results = context.getResults(capnp::MessageSize { 4, 1 });
      results.setNext(factory.streamSet.add(
          kj::heap<CapnpToKjStreamAdapter>(factory, kj::mv(pathProber))));

      // The full pump completed. Note that it's important that we fulfill this after the
      // PathProber has been attached to the new CapnpToKjStreamAdapter, which will have happened
      // in CapnpToKjStreamAdapter's constructor, which calls pathProber->setNewParent().
      originalPumpfulfiller->fulfill(kj::cp(originalPumpLimit));

      return kj::READY_NOW;
    }

  private:
    ByteStreamFactory& factory;
    kj::Own<PathProber> pathProber;
    kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller;
    uint64_t originalPumpLimit;
    bool done = false;  // set once ended() or reachedLimit() has run
  };
};
690 
691 // =======================================================================================
692 
693 class ByteStreamFactory::KjToCapnpStreamAdapter final: public kj::AsyncOutputStream {
694 public:
KjToCapnpStreamAdapter(ByteStreamFactory & factory,capnp::ByteStream::Client innerParam)695   KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam)
696       : factory(factory),
697         inner(kj::mv(innerParam)),
698         findShorterPathTask(findShorterPath(inner).fork()) {}
699 
~KjToCapnpStreamAdapter()700   ~KjToCapnpStreamAdapter() noexcept(false) {
701     // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We
702     //   use a detached promise for now, which is probably OK since capabilities are refcounted and
703     //   asynchronously destroyed anyway.
704     // TODO(cleanup): Fix this when KJ streads add an explicit end() method.
705     KJ_IF_MAYBE(o, optimized) {
706       o->directEnd();
707     } else {
708       inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
709     }
710   }
711 
write(const void * buffer,size_t size)712   kj::Promise<void> write(const void* buffer, size_t size) override {
713     KJ_SWITCH_ONEOF(getShortestPath()) {
714       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
715         return promise.then([this,buffer,size]() {
716           return write(buffer, size);
717         });
718       }
719       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
720         auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE);
721         if (size <= limit) {
722           auto promise = kjStream.stream.write(buffer, size);
723           return promise.then([kjStream,size]() mutable {
724             kjStream.lender.returnStream(size);
725           });
726         } else {
727           auto promise = kjStream.stream.write(buffer, limit);
728           return promise.then([this,kjStream,buffer,size,limit]() mutable {
729             kjStream.lender.returnStream(limit);
730             return write(reinterpret_cast<const byte*>(buffer) + limit,
731                          size - limit);
732           });
733         }
734       }
735       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
736         if (size <= MAX_BYTES_PER_WRITE) {
737           auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 });
738           req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size));
739           return req.send();
740         } else {
741           auto req = capnpStream->writeRequest(
742               MessageSize { 8 + MAX_BYTES_PER_WRITE / sizeof(word), 0 });
743           req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), MAX_BYTES_PER_WRITE));
744           return req.send().then([this,buffer,size]() mutable {
745             return write(reinterpret_cast<const byte*>(buffer) + MAX_BYTES_PER_WRITE,
746                          size - MAX_BYTES_PER_WRITE);
747           });
748         }
749       }
750     }
751     KJ_UNREACHABLE;
752   }
753 
write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces)754   kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
755     KJ_SWITCH_ONEOF(getShortestPath()) {
756       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
757         return promise.then([this,pieces]() {
758           return write(pieces);
759         });
760       }
761       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
762         size_t size = 0;
763         for (auto& piece: pieces) { size += piece.size(); }
764         auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE);
765         if (size <= limit) {
766           auto promise = kjStream.stream.write(pieces);
767           return promise.then([kjStream,size]() mutable {
768             kjStream.lender.returnStream(size);
769           });
770         } else {
771           // ughhhhhhhhhh, we need to split the pieces.
772           return splitAndWrite(pieces, kjStream.limit,
773               [kjStream,limit](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) mutable {
774             return kjStream.stream.write(pieces).then([kjStream,limit]() mutable {
775               kjStream.lender.returnStream(limit);
776             });
777           });
778         }
779       }
780       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
781         auto writePieces = [capnpStream](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
782           size_t size = 0;
783           for (auto& piece: pieces) size += piece.size();
784           auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 });
785           auto out = req.initBytes(size);
786           byte* ptr = out.begin();
787           for (auto& piece: pieces) {
788             memcpy(ptr, piece.begin(), piece.size());
789             ptr += piece.size();
790           }
791           KJ_ASSERT(ptr == out.end());
792           return req.send();
793         };
794 
795         size_t size = 0;
796         for (auto& piece: pieces) size += piece.size();
797         if (size <= MAX_BYTES_PER_WRITE) {
798           return writePieces(pieces);
799         } else {
800           // ughhhhhhhhhh, we need to split the pieces.
801           return splitAndWrite(pieces, MAX_BYTES_PER_WRITE, writePieces);
802         }
803       }
804     }
805     KJ_UNREACHABLE;
806   }
807 
tryPumpFrom(kj::AsyncInputStream & input,uint64_t amount=kj::maxValue)808   kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
809       kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
810     KJ_IF_MAYBE(rpc, kj::dynamicDowncastIfAvailable<CapnpToKjStreamAdapter::PathProber>(input)) {
811       // Oh interesting, it turns we're hosting an incoming ByteStream which is pumping to this
812       // outgoing ByteStream. We can let the Cap'n Proto RPC layer know that it can shorten the
813       // path from one to the other.
814       return rpc->pumpToShorterPath(inner, amount);
815     } else {
816       return pumpLoop(input, 0, amount);
817     }
818   }
819 
whenWriteDisconnected()820   kj::Promise<void> whenWriteDisconnected() override {
821     return findShorterPathTask.addBranch();
822   }
823 
824 private:
825   ByteStreamFactory& factory;
826   capnp::ByteStream::Client inner;
827   kj::Maybe<StreamServerBase&> optimized;
828 
829   kj::ForkedPromise<void> findShorterPathTask;
830   // This serves two purposes:
831   // 1. Waits for the capability to resolve (if it is a promise), and then shortens the path if
832   //    possible.
833   // 2. Implements whenWriteDisconnected().
834 
findShorterPath(capnp::ByteStream::Client & capnpClient)835   kj::Promise<void> findShorterPath(capnp::ByteStream::Client& capnpClient) {
836     // If the capnp stream turns out to resolve back to this process, shorten the path.
837     // Also, implement whenWriteDisconnected() based on this.
838     return factory.streamSet.getLocalServer(capnpClient)
839         .then([this](kj::Maybe<capnp::ByteStream::Server&> server) -> kj::Promise<void> {
840       KJ_IF_MAYBE(s, server) {
841         // Yay, we discovered that the ByteStream actually points back to a local KJ stream.
842         // We can use this to shorten the path by skipping the RPC machinery.
843         return findShorterPath(kj::downcast<StreamServerBase>(*s));
844       } else {
845         // The capability is fully-resolved. This suggests that the remote implementation is
846         // NOT a CapnpToKjStreamAdapter at all, because CapnpToKjStreamAdapter is designed to
847         // always look like a promise. It's some other implementation that doesn't present
848         // itself as a promise. We have no way to detect when it is disconnected.
849         return kj::NEVER_DONE;
850       }
851     }, [](kj::Exception&& e) -> kj::Promise<void> {
852       // getLocalServer() thrown when the capability is a promise cap that rejects. We can
853       // use this to implement whenWriteDisconnected().
854       //
855       // (Note that because this exception handler is passed to the .then(), it does NOT catch
856       // eoxceptions thrown by the success handler immediately above it. This handler will ONLY
857       // catch exceptions from getLocalServer() itself.)
858       return kj::READY_NOW;
859     });
860   }
861 
findShorterPath(StreamServerBase & capnpServer)862   kj::Promise<void> findShorterPath(StreamServerBase& capnpServer) {
863     // We found a shorter path back to this process. Record it.
864     optimized = capnpServer;
865 
866     KJ_SWITCH_ONEOF(capnpServer.getShortestPath()) {
867       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
868         return promise.then([this,&capnpServer]() {
869           return findShorterPath(capnpServer);
870         });
871       }
872       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
873         // The ByteStream::Server wraps a regular KJ stream that does not wrap another capnp
874         // stream.
875         if (kjStream.limit < (uint64_t)kj::maxValue / 2) {
876           // But it isn't wrapping that stream forever. Eventually it plans to redirect back to
877           // some other stream. So, let's wait for that, and possibly shorten again.
878           kjStream.lender.returnStream(0);
879           return KJ_ASSERT_NONNULL(capnpServer.shortenPath())
880               .then([this, &capnpServer](auto&&) {
881             return findShorterPath(capnpServer);
882           });
883         } else {
884           // This KJ stream is (effectively) the permanent endpoint. We can't get any shorter
885           // from here. All we want to do now is watch for disconnect.
886           auto promise = kjStream.stream.whenWriteDisconnected();
887           kjStream.lender.returnStream(0);
888           return promise;
889         }
890       }
891       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
892         return findShorterPath(*capnpStream);
893       }
894     }
895     KJ_UNREACHABLE;
896   }
897 
getShortestPath()898   StreamServerBase::ShortestPath getShortestPath() {
899     KJ_IF_MAYBE(o, optimized) {
900       return o->getShortestPath();
901     } else {
902       return &inner;
903     }
904   }
905 
pumpLoop(kj::AsyncInputStream & input,uint64_t completed,uint64_t remaining)906   kj::Promise<uint64_t> pumpLoop(kj::AsyncInputStream& input,
907                                  uint64_t completed, uint64_t remaining) {
908     if (remaining == 0) return completed;
909 
910     KJ_SWITCH_ONEOF(getShortestPath()) {
911       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
912         return promise.then([this,&input,completed,remaining]() {
913           return pumpLoop(input,completed,remaining);
914         });
915       }
916       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
917         // Oh hell yes, this capability actually points back to a stream in our own thread. We can
918         // stop sending RPCs and just pump directly.
919 
920         if (remaining <= kjStream.limit) {
921           return input.pumpTo(kjStream.stream, remaining)
922               .then([kjStream,completed](uint64_t actual) {
923             kjStream.lender.returnStream(actual);
924             return actual + completed;
925           });
926         } else {
927           auto promise = input.pumpTo(kjStream.stream, kjStream.limit);
928           return promise.then([this,&input,completed,remaining,kjStream]
929                               (uint64_t actual) mutable -> kj::Promise<uint64_t> {
930             kjStream.lender.returnStream(actual);
931             if (actual < kjStream.limit) {
932               // EOF reached.
933               return completed + actual;
934             } else {
935               return pumpLoop(input, completed + actual, remaining - actual);
936             }
937           });
938         }
939       }
940       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
941         // Pumping from some other kind of steram. Optimize the pump by reading from the input
942         // directly into outgoing RPC messages.
943         size_t size = kj::min(remaining, 8192);
944         auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word) });
945 
946         auto orphanage = Orphanage::getForMessageContaining(
947             capnp::ByteStream::WriteParams::Builder(req));
948 
949         auto buffer = orphanage.newOrphan<Data>(size);
950 
951         struct WriteRequestAndBuffer {
952           // The order of construction/destruction of lambda captures is unspecified, but we care
953           // about ordering between these two things that we want to capture, so... we need a
954           // struct.
955           StreamingRequest<capnp::ByteStream::WriteParams> request;
956           Orphan<Data> buffer;  // points into `request`...
957         };
958 
959         WriteRequestAndBuffer wrab = { kj::mv(req), kj::mv(buffer) };
960 
961         return input.tryRead(wrab.buffer.get().begin(), 1, size)
962             .then([this, &input, completed, remaining, size, wrab = kj::mv(wrab)]
963                   (size_t actual) mutable -> kj::Promise<uint64_t> {
964           if (actual == 0) {
965             return completed;
966           } if (actual < size) {
967             wrab.buffer.truncate(actual);
968           }
969 
970           wrab.request.adoptBytes(kj::mv(wrab.buffer));
971           return wrab.request.send()
972               .then([this, &input, completed, remaining, actual]() {
973             return pumpLoop(input, completed + actual, remaining - actual);
974           });
975         });
976       }
977     }
978     KJ_UNREACHABLE;
979   }
980 
981   template <typename WritePieces>
splitAndWrite(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces,size_t limit,WritePieces && writeFirstPieces)982   kj::Promise<void> splitAndWrite(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces,
983                                   size_t limit, WritePieces&& writeFirstPieces) {
984     size_t splitByte = limit;
985     size_t splitPiece = 0;
986     while (pieces[splitPiece].size() <= splitByte) {
987       splitByte -= pieces[splitPiece].size();
988       ++splitPiece;
989     }
990 
991     if (splitByte == 0) {
992       // Oh thank god, the split is between two pieces.
993       auto rest = pieces.slice(splitPiece, pieces.size());
994       return writeFirstPieces(pieces.slice(0, splitPiece))
995           .then([this,rest]() mutable {
996         return write(rest);
997       });
998     } else {
999       // FUUUUUUUU---- we need to split one of the pieces in two.
1000       auto left = kj::heapArray<kj::ArrayPtr<const byte>>(splitPiece + 1);
1001       auto right = kj::heapArray<kj::ArrayPtr<const byte>>(pieces.size() - splitPiece);
1002       for (auto i: kj::zeroTo(splitPiece)) {
1003         left[i] = pieces[i];
1004       }
1005       for (auto i: kj::zeroTo(right.size())) {
1006         right[i] = pieces[splitPiece + i];
1007       }
1008       left.back() = pieces[splitPiece].slice(0, splitByte);
1009       right.front() = pieces[splitPiece].slice(splitByte, pieces[splitPiece].size());
1010 
1011       return writeFirstPieces(left).attach(kj::mv(left))
1012           .then([this,right=kj::mv(right)]() mutable {
1013         return write(right).attach(kj::mv(right));
1014       });
1015     }
1016   }
1017 };
1018 
1019 // =======================================================================================
1020 
kjToCapnp(kj::Own<kj::AsyncOutputStream> kjStream)1021 capnp::ByteStream::Client ByteStreamFactory::kjToCapnp(kj::Own<kj::AsyncOutputStream> kjStream) {
1022   return streamSet.add(kj::heap<CapnpToKjStreamAdapter>(*this, kj::mv(kjStream)));
1023 }
1024 
capnpToKj(capnp::ByteStream::Client capnpStream)1025 kj::Own<kj::AsyncOutputStream> ByteStreamFactory::capnpToKj(capnp::ByteStream::Client capnpStream) {
1026   return kj::heap<KjToCapnpStreamAdapter>(*this, kj::mv(capnpStream));
1027 }
1028 
1029 }  // namespace capnp
1030