1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <folly/io/async/AsyncSSLSocket.h>
18 
19 #include <folly/futures/Promise.h>
20 #include <folly/init/Init.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/io/async/SSLContext.h>
23 #include <folly/io/async/ScopedEventBaseThread.h>
24 #include <folly/io/async/test/AsyncSSLSocketTest.h>
25 #include <folly/portability/GTest.h>
26 #include <folly/portability/PThread.h>
27 #include <folly/ssl/Init.h>
28 
29 using std::cerr;
30 using std::endl;
31 
32 namespace folly {
33 
34 struct EvbAndContext {
EvbAndContextfolly::EvbAndContext35   EvbAndContext() {
36     ctx_.reset(new SSLContext());
37     ctx_->setOptions(SSL_OP_NO_TICKET);
38     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
39   }
40 
createSocketfolly::EvbAndContext41   std::shared_ptr<AsyncSSLSocket> createSocket() {
42     return AsyncSSLSocket::newSocket(ctx_, getEventBase());
43   }
44 
getEventBasefolly::EvbAndContext45   EventBase* getEventBase() { return evb_.getEventBase(); }
46 
attachfolly::EvbAndContext47   void attach(AsyncSSLSocket& socket) {
48     socket.attachEventBase(getEventBase());
49     socket.attachSSLContext(ctx_);
50   }
51 
52   folly::ScopedEventBaseThread evb_;
53   std::shared_ptr<SSLContext> ctx_;
54 };
55 
56 class AttachDetachClient : public AsyncSocket::ConnectCallback,
57                            public AsyncTransport::WriteCallback,
58                            public AsyncTransport::ReadCallback {
59  private:
60   // two threads here - we'll create the socket in one, connect
61   // in the other, and then read/write in the initial one
62   EvbAndContext t1_;
63   EvbAndContext t2_;
64   std::shared_ptr<AsyncSSLSocket> sslSocket_;
65   folly::SocketAddress address_;
66   char buf_[128];
67   char readbuf_[128];
68   uint32_t bytesRead_;
69   // promise to fulfill when done
70   folly::Promise<bool> promise_;
71 
detach()72   void detach() {
73     sslSocket_->detachEventBase();
74     sslSocket_->detachSSLContext();
75   }
76 
77  public:
AttachDetachClient(const folly::SocketAddress & address)78   explicit AttachDetachClient(const folly::SocketAddress& address)
79       : address_(address), bytesRead_(0) {}
80 
getFuture()81   Future<bool> getFuture() { return promise_.getFuture(); }
82 
connect()83   void connect() {
84     // create in one and then move to another
85     auto t1Evb = t1_.getEventBase();
86     t1Evb->runInEventBaseThread([this] {
87       sslSocket_ = t1_.createSocket();
88       // ensure we can detach and reattach the context before connecting
89       for (int i = 0; i < 1000; ++i) {
90         sslSocket_->detachSSLContext();
91         sslSocket_->attachSSLContext(t1_.ctx_);
92       }
93       // detach from t1 and connect in t2
94       detach();
95       auto t2Evb = t2_.getEventBase();
96       t2Evb->runInEventBaseThread([this] {
97         t2_.attach(*sslSocket_);
98         sslSocket_->connect(this, address_);
99       });
100     });
101   }
102 
connectSuccess()103   void connectSuccess() noexcept override {
104     auto t2Evb = t2_.getEventBase();
105     EXPECT_TRUE(t2Evb->isInEventBaseThread());
106     cerr << "client SSL socket connected" << endl;
107     for (int i = 0; i < 1000; ++i) {
108       sslSocket_->detachSSLContext();
109       sslSocket_->attachSSLContext(t2_.ctx_);
110     }
111 
112     // detach from t2 and then read/write in t1
113     t2Evb->runInEventBaseThread([this] {
114       detach();
115       auto t1Evb = t1_.getEventBase();
116       t1Evb->runInEventBaseThread([this] {
117         t1_.attach(*sslSocket_);
118         sslSocket_->write(this, buf_, sizeof(buf_));
119         sslSocket_->setReadCB(this);
120         memset(readbuf_, 'b', sizeof(readbuf_));
121         bytesRead_ = 0;
122       });
123     });
124   }
125 
connectErr(const AsyncSocketException & ex)126   void connectErr(const AsyncSocketException& ex) noexcept override {
127     cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
128     sslSocket_.reset();
129   }
130 
writeSuccess()131   void writeSuccess() noexcept override {
132     cerr << "client write success" << endl;
133   }
134 
writeErr(size_t,const AsyncSocketException & ex)135   void writeErr(
136       size_t /* bytesWritten */,
137       const AsyncSocketException& ex) noexcept override {
138     cerr << "client writeError: " << ex.what() << endl;
139   }
140 
getReadBuffer(void ** bufReturn,size_t * lenReturn)141   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
142     *bufReturn = readbuf_ + bytesRead_;
143     *lenReturn = sizeof(readbuf_) - bytesRead_;
144   }
readEOF()145   void readEOF() noexcept override { cerr << "client readEOF" << endl; }
146 
readErr(const AsyncSocketException & ex)147   void readErr(const AsyncSocketException& ex) noexcept override {
148     cerr << "client readError: " << ex.what() << endl;
149     promise_.setException(ex);
150   }
151 
readDataAvailable(size_t len)152   void readDataAvailable(size_t len) noexcept override {
153     EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
154     EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
155     cerr << "client read data: " << len << endl;
156     bytesRead_ += len;
157     if (len == sizeof(buf_)) {
158       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
159       sslSocket_->closeNow();
160       sslSocket_.reset();
161       promise_.setValue(true);
162     }
163   }
164 };
165 
166 /**
167  * Test passing contexts between threads
168  */
TEST(AsyncSSLSocketTest2,AttachDetachSSLContext)169 TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
170   // Start listening on a local port
171   WriteCallbackBase writeCallback;
172   ReadCallback readCallback(&writeCallback);
173   HandshakeCallback handshakeCallback(&readCallback);
174   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
175   TestSSLServer server(&acceptCallback);
176 
177   std::shared_ptr<AttachDetachClient> client(
178       new AttachDetachClient(server.getAddress()));
179 
180   auto f = client->getFuture();
181   client->connect();
182   EXPECT_TRUE(std::move(f).within(std::chrono::seconds(3)).get());
183 }
184 
185 class ConnectClient : public AsyncSocket::ConnectCallback {
186  public:
187   ConnectClient() = default;
188 
getFuture()189   Future<bool> getFuture() { return promise_.getFuture(); }
190 
connect(const folly::SocketAddress & addr)191   void connect(const folly::SocketAddress& addr) {
192     t1_.getEventBase()->runInEventBaseThread([&] {
193       socket_ = t1_.createSocket();
194       socket_->connect(this, addr);
195     });
196   }
197 
connectSuccess()198   void connectSuccess() noexcept override {
199     socket_.reset();
200     promise_.setValue(true);
201   }
202 
connectErr(const AsyncSocketException &)203   void connectErr(const AsyncSocketException& /* ex */) noexcept override {
204     socket_.reset();
205     promise_.setValue(false);
206   }
207 
setCtx(std::shared_ptr<SSLContext> ctx)208   void setCtx(std::shared_ptr<SSLContext> ctx) { t1_.ctx_ = ctx; }
209 
210  private:
211   EvbAndContext t1_;
212   // promise to fulfill when done with a value of true if connect succeeded
213   folly::Promise<bool> promise_;
214   std::shared_ptr<AsyncSSLSocket> socket_;
215 };
216 
217 class NoopReadCallback : public ReadCallbackBase {
218  public:
NoopReadCallback()219   NoopReadCallback() : ReadCallbackBase(nullptr) { state = STATE_SUCCEEDED; }
220 
getReadBuffer(void ** buf,size_t * lenReturn)221   void getReadBuffer(void** buf, size_t* lenReturn) override {
222     *buf = &buffer_;
223     *lenReturn = 1;
224   }
readDataAvailable(size_t)225   void readDataAvailable(size_t) noexcept override {}
226 
227   uint8_t buffer_{0};
228 };
229 
// A ConnectClient using its default context is expected to connect
// successfully to a server whose SSLContext is built with
// SSLContext::TLSv1_2.
TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
  // Server side: listen on a local port, handshake, discard any reads.
  NoopReadCallback noopReader;
  HandshakeCallback handshakeCb(&noopReader);
  SSLServerAcceptCallbackDelay acceptCb(&handshakeCb);
  auto serverCtx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  TestSSLServer server(&acceptCb, serverCtx);
  server.loadTestCerts();

  // Client side: default context, expected to connect.
  auto defaultClient = std::make_unique<ConnectClient>();
  auto connected = defaultClient->getFuture();
  defaultClient->connect(server.getAddress());
  EXPECT_TRUE(std::move(connected).within(std::chrono::seconds(3)).get());
}
245 
// A client whose context sets SSL_OP_NO_TLSv1_2 is expected to fail to
// connect to a server whose SSLContext is built with SSLContext::TLSv1_2.
// The server-side handshake callback is set up to expect an error.
TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
  // Server side: listen on a local port, expect the handshake to fail.
  NoopReadCallback noopReader;
  HandshakeCallback handshakeCb(&noopReader, HandshakeCallback::EXPECT_ERROR);
  SSLServerAcceptCallbackDelay acceptCb(&handshakeCb);
  auto serverCtx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  TestSSLServer server(&acceptCb, serverCtx);
  server.loadTestCerts();

  // Client side: a context with TLS 1.2 disabled, expected not to connect.
  auto badClient = std::make_unique<ConnectClient>();
  auto noTls12Ctx = std::make_shared<SSLContext>();
  noTls12Ctx->setOptions(SSL_OP_NO_TLSv1_2);
  badClient->setCtx(noTls12Ctx);
  auto connected = badClient->getFuture();
  badClient->connect(server.getAddress());
  EXPECT_FALSE(std::move(connected).within(std::chrono::seconds(3)).get());
}
265 
266 } // namespace folly
267 
main(int argc,char * argv[])268 int main(int argc, char* argv[]) {
269   folly::ssl::init();
270 #ifdef SIGPIPE
271   signal(SIGPIPE, SIG_IGN);
272 #endif
273   testing::InitGoogleTest(&argc, argv);
274   folly::init(&argc, &argv);
275   return RUN_ALL_TESTS();
276   OPENSSL_cleanup();
277 }
278