1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <folly/ExceptionString.h>
20 #include <folly/io/Cursor.h>
21 #include <folly/io/async/DecoratedAsyncTransportWrapper.h>
22 #include <thrift/lib/cpp/server/TServerObserver.h>
23 #include <thrift/lib/cpp2/server/Cpp2Worker.h>
24 #include <thrift/lib/cpp2/server/peeking/TLSHelper.h>
25 #include <wangle/acceptor/Acceptor.h>
26 #include <wangle/acceptor/ManagedConnection.h>
27 #include <wangle/acceptor/SocketPeeker.h>
28 
29 namespace apache {
30 namespace thrift {
31 
32 /**
33  * The number of bytes that will be read from the socket.
34  * TLSHelper currently needs the most bytes. Thus, it's cap
35  * it up at the amount that TLSHelper needs.
36  */
37 constexpr uint8_t kPeekBytes = 13;
38 
39 /**
40  * A manager that rejects or accepts connections based on critera
41  * added by helper functions. This is useful for cases where
42  * clients might be sending different types of protocols
43  * over plaintext and it's up to the Acceptor to determine
44  * what kind of protocol they are talking to route to the
45  * appropriate handlers.
46  */
47 class PeekingManagerBase : public wangle::ManagedConnection {
48  public:
PeekingManagerBase(std::shared_ptr<apache::thrift::Cpp2Worker> acceptor,const folly::SocketAddress & clientAddr,wangle::TransportInfo tinfo,apache::thrift::ThriftServer * server)49   PeekingManagerBase(
50       std::shared_ptr<apache::thrift::Cpp2Worker> acceptor,
51       const folly::SocketAddress& clientAddr,
52       wangle::TransportInfo tinfo,
53       apache::thrift::ThriftServer* server)
54       : acceptor_(acceptor),
55         clientAddr_(clientAddr),
56         tinfo_(std::move(tinfo)),
57         server_(server) {
58     acceptor_->getConnectionManager()->addConnection(this, true);
59   }
60 
timeoutExpired()61   void timeoutExpired() noexcept override { dropConnection(); }
62 
63   void dropConnection(const std::string& /* errorMsg */ = "") override {
64     acceptor_->getConnectionManager()->removeConnection(this);
65     destroy();
66   }
67 
describe(std::ostream & os)68   void describe(std::ostream& os) const override {
69     os << "Peeking the socket " << clientAddr_;
70   }
71 
isBusy()72   bool isBusy() const override { return true; }
73 
notifyPendingShutdown()74   void notifyPendingShutdown() override {}
75 
closeWhenIdle()76   void closeWhenIdle() override {}
77 
dumpConnectionState(uint8_t)78   void dumpConnectionState(uint8_t /* loglevel */) override {}
79 
80  protected:
81   const std::shared_ptr<apache::thrift::Cpp2Worker> acceptor_;
82   const folly::SocketAddress clientAddr_;
83   wangle::TransportInfo tinfo_;
84   ThriftServer* const server_;
85 };
86 
87 class CheckTLSPeekingManager : public PeekingManagerBase,
88                                public wangle::SocketPeeker::Callback {
89  public:
CheckTLSPeekingManager(std::shared_ptr<apache::thrift::Cpp2Worker> acceptor,const folly::SocketAddress & clientAddr,wangle::TransportInfo tinfo,apache::thrift::ThriftServer * server,folly::AsyncSocket::UniquePtr socket,std::shared_ptr<apache::thrift::server::TServerObserver> obs)90   CheckTLSPeekingManager(
91       std::shared_ptr<apache::thrift::Cpp2Worker> acceptor,
92       const folly::SocketAddress& clientAddr,
93       wangle::TransportInfo tinfo,
94       apache::thrift::ThriftServer* server,
95       folly::AsyncSocket::UniquePtr socket,
96       std::shared_ptr<apache::thrift::server::TServerObserver> obs)
97       : PeekingManagerBase(
98             std::move(acceptor), clientAddr, std::move(tinfo), server),
99         socket_(std::move(socket)),
100         observer_(std::move(obs)),
101         peeker_(new wangle::SocketPeeker(*socket_, this, kPeekBytes)) {
102     peeker_->start();
103   }
104 
~CheckTLSPeekingManager()105   ~CheckTLSPeekingManager() override {
106     if (socket_) {
107       socket_->closeNow();
108     }
109   }
110 
peekSuccess(std::vector<uint8_t> peekBytes)111   void peekSuccess(std::vector<uint8_t> peekBytes) noexcept override {
112     folly::DelayedDestruction::DestructorGuard dg(this);
113     dropConnection();
114     if (TLSHelper::looksLikeTLS(peekBytes)) {
115       LOG(ERROR) << "Received SSL connection on non SSL port";
116       sendPlaintextTLSAlert(peekBytes);
117       if (observer_) {
118         observer_->protocolError();
119       }
120       return;
121     }
122     acceptor_->connectionReady(
123         std::move(socket_),
124         std::move(clientAddr_),
125         {},
126         SecureTransportType::NONE,
127         tinfo_);
128   }
129 
sendPlaintextTLSAlert(const std::vector<uint8_t> & peekBytes)130   void sendPlaintextTLSAlert(const std::vector<uint8_t>& peekBytes) {
131     uint8_t major = peekBytes[1];
132     uint8_t minor = peekBytes[2];
133     auto alert = TLSHelper::getPlaintextAlert(
134         major, minor, TLSHelper::Alert::UNEXPECTED_MESSAGE);
135     socket_->writeChain(nullptr, std::move(alert));
136   }
137 
peekError(const folly::AsyncSocketException &)138   void peekError(const folly::AsyncSocketException&) noexcept override {
139     dropConnection();
140   }
141 
142   void dropConnection(const std::string& errorMsg = "") override {
143     folly::DelayedDestruction::DestructorGuard dg(this);
144     peeker_ = nullptr;
145     PeekingManagerBase::dropConnection(errorMsg);
146   }
147 
148  private:
149   folly::AsyncSocket::UniquePtr socket_;
150   std::shared_ptr<apache::thrift::server::TServerObserver> observer_;
151   typename wangle::SocketPeeker::UniquePtr peeker_;
152 };
153 
154 class PreReceivedDataAsyncTransportWrapper
155     : public folly::DecoratedAsyncTransportWrapper<folly::AsyncTransport> {
156   using Base = folly::DecoratedAsyncTransportWrapper<folly::AsyncTransport>;
157 
158  public:
159   using UniquePtr = std::unique_ptr<AsyncTransport, Destructor>;
160 
create(folly::AsyncTransport::UniquePtr socket,std::vector<uint8_t> preReceivedData)161   static UniquePtr create(
162       folly::AsyncTransport::UniquePtr socket,
163       std::vector<uint8_t> preReceivedData) {
164     DCHECK(!socket->getReadCallback());
165     return UniquePtr(new PreReceivedDataAsyncTransportWrapper(
166         std::move(socket), std::move(preReceivedData)));
167   }
168 
getReadCallback()169   ReadCallback* getReadCallback() const override { return readCallback_; }
170 
setReadCB(folly::AsyncTransport::ReadCallback * callback)171   void setReadCB(folly::AsyncTransport::ReadCallback* callback) override {
172     folly::DelayedDestruction::DestructorGuard dg(this);
173     readCallback_ = callback;
174     if (preReceivedData_) {
175       if (!readCallback_) {
176         return;
177       }
178       const auto preReceivedData = std::exchange(preReceivedData_, {});
179       void* buf;
180       size_t bufSize;
181       callback->getReadBuffer(&buf, &bufSize);
182       CHECK(callback == readCallback_);
183       CHECK(bufSize >= preReceivedData->size());
184       std::memcpy(buf, preReceivedData->data(), preReceivedData->size());
185       callback->readDataAvailable(preReceivedData->size());
186     }
187     if (readCallback_ == callback) {
188       Base::setReadCB(callback);
189     }
190   }
191 
192  private:
PreReceivedDataAsyncTransportWrapper(folly::AsyncTransport::UniquePtr socket,std::vector<uint8_t> preReceivedData)193   PreReceivedDataAsyncTransportWrapper(
194       folly::AsyncTransport::UniquePtr socket,
195       std::vector<uint8_t> preReceivedData)
196       : Base(std::move(socket)),
197         preReceivedData_(
198             preReceivedData.size() ? std::make_unique<std::vector<uint8_t>>(
199                                          std::move(preReceivedData))
200                                    : std::unique_ptr<std::vector<uint8_t>>()) {}
201 
202   std::unique_ptr<std::vector<uint8_t>> preReceivedData_;
203   folly::AsyncTransport::ReadCallback* readCallback_{};
204 };
205 
206 class TransportPeekingManager : public PeekingManagerBase,
207                                 public wangle::SocketPeeker::Callback {
208  public:
TransportPeekingManager(std::shared_ptr<apache::thrift::Cpp2Worker> acceptor,const folly::SocketAddress & clientAddr,wangle::TransportInfo tinfo,apache::thrift::ThriftServer * server,folly::AsyncTransport::UniquePtr socket)209   TransportPeekingManager(
210       std::shared_ptr<apache::thrift::Cpp2Worker> acceptor,
211       const folly::SocketAddress& clientAddr,
212       wangle::TransportInfo tinfo,
213       apache::thrift::ThriftServer* server,
214       folly::AsyncTransport::UniquePtr socket)
215       : PeekingManagerBase(
216             std::move(acceptor), clientAddr, std::move(tinfo), server),
217         socket_(std::move(socket)),
218         peeker_(new wangle::TransportPeeker(*socket_, this, kPeekBytes)) {
219     peeker_->start();
220   }
221 
~TransportPeekingManager()222   ~TransportPeekingManager() override {
223     if (socket_) {
224       socket_->closeNow();
225     }
226   }
227 
peekSuccess(std::vector<uint8_t> peekBytes)228   void peekSuccess(std::vector<uint8_t> peekBytes) noexcept override {
229     folly::DelayedDestruction::DestructorGuard dg(this);
230     dropConnection();
231 
232     // This is possible when acceptor is stopped between taking a new
233     // connection and calling back to this function.
234     if (acceptor_->isStopping()) {
235       return;
236     }
237 
238     auto transport = PreReceivedDataAsyncTransportWrapper::create(
239         std::move(socket_), peekBytes);
240 
241     try {
242       // Check for new transports
243       bool acceptedHandler = false;
244       for (auto const& handler : *server_->getRoutingHandlers()) {
245         if (handler->canAcceptConnection(peekBytes)) {
246           handler->handleConnection(
247               acceptor_->getConnectionManager(),
248               std::move(transport),
249               &clientAddr_,
250               tinfo_,
251               acceptor_);
252           acceptedHandler = true;
253           break;
254         }
255       }
256 
257       // Default to Header Transport
258       if (!acceptedHandler) {
259         acceptor_->handleHeader(std::move(transport), &clientAddr_, tinfo_);
260       }
261     } catch (...) {
262       LOG(ERROR) << __func__ << " failed, dropping connection: "
263                  << folly::exceptionStr(std::current_exception());
264     }
265   }
266 
peekError(const folly::AsyncSocketException &)267   void peekError(const folly::AsyncSocketException&) noexcept override {
268     dropConnection();
269   }
270 
271   void dropConnection(const std::string& errorMsg = "") override {
272     folly::DelayedDestruction::DestructorGuard dg(this);
273     peeker_ = nullptr;
274     PeekingManagerBase::dropConnection(errorMsg);
275   }
276 
277  private:
278   folly::AsyncTransport::UniquePtr socket_;
279   typename wangle::TransportPeeker::UniquePtr peeker_;
280 };
281 
282 } // namespace thrift
283 } // namespace apache
284