1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/futures/Promise.h>
20 #include <folly/init/Init.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/io/async/SSLContext.h>
23 #include <folly/io/async/ScopedEventBaseThread.h>
24 #include <folly/io/async/test/AsyncSSLSocketTest.h>
25 #include <folly/portability/GTest.h>
26 #include <folly/portability/PThread.h>
27 #include <folly/ssl/Init.h>
28
29 using std::cerr;
30 using std::endl;
31
32 namespace folly {
33
34 struct EvbAndContext {
EvbAndContextfolly::EvbAndContext35 EvbAndContext() {
36 ctx_.reset(new SSLContext());
37 ctx_->setOptions(SSL_OP_NO_TICKET);
38 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
39 }
40
createSocketfolly::EvbAndContext41 std::shared_ptr<AsyncSSLSocket> createSocket() {
42 return AsyncSSLSocket::newSocket(ctx_, getEventBase());
43 }
44
getEventBasefolly::EvbAndContext45 EventBase* getEventBase() { return evb_.getEventBase(); }
46
attachfolly::EvbAndContext47 void attach(AsyncSSLSocket& socket) {
48 socket.attachEventBase(getEventBase());
49 socket.attachSSLContext(ctx_);
50 }
51
52 folly::ScopedEventBaseThread evb_;
53 std::shared_ptr<SSLContext> ctx_;
54 };
55
56 class AttachDetachClient : public AsyncSocket::ConnectCallback,
57 public AsyncTransport::WriteCallback,
58 public AsyncTransport::ReadCallback {
59 private:
60 // two threads here - we'll create the socket in one, connect
61 // in the other, and then read/write in the initial one
62 EvbAndContext t1_;
63 EvbAndContext t2_;
64 std::shared_ptr<AsyncSSLSocket> sslSocket_;
65 folly::SocketAddress address_;
66 char buf_[128];
67 char readbuf_[128];
68 uint32_t bytesRead_;
69 // promise to fulfill when done
70 folly::Promise<bool> promise_;
71
detach()72 void detach() {
73 sslSocket_->detachEventBase();
74 sslSocket_->detachSSLContext();
75 }
76
77 public:
AttachDetachClient(const folly::SocketAddress & address)78 explicit AttachDetachClient(const folly::SocketAddress& address)
79 : address_(address), bytesRead_(0) {}
80
getFuture()81 Future<bool> getFuture() { return promise_.getFuture(); }
82
connect()83 void connect() {
84 // create in one and then move to another
85 auto t1Evb = t1_.getEventBase();
86 t1Evb->runInEventBaseThread([this] {
87 sslSocket_ = t1_.createSocket();
88 // ensure we can detach and reattach the context before connecting
89 for (int i = 0; i < 1000; ++i) {
90 sslSocket_->detachSSLContext();
91 sslSocket_->attachSSLContext(t1_.ctx_);
92 }
93 // detach from t1 and connect in t2
94 detach();
95 auto t2Evb = t2_.getEventBase();
96 t2Evb->runInEventBaseThread([this] {
97 t2_.attach(*sslSocket_);
98 sslSocket_->connect(this, address_);
99 });
100 });
101 }
102
connectSuccess()103 void connectSuccess() noexcept override {
104 auto t2Evb = t2_.getEventBase();
105 EXPECT_TRUE(t2Evb->isInEventBaseThread());
106 cerr << "client SSL socket connected" << endl;
107 for (int i = 0; i < 1000; ++i) {
108 sslSocket_->detachSSLContext();
109 sslSocket_->attachSSLContext(t2_.ctx_);
110 }
111
112 // detach from t2 and then read/write in t1
113 t2Evb->runInEventBaseThread([this] {
114 detach();
115 auto t1Evb = t1_.getEventBase();
116 t1Evb->runInEventBaseThread([this] {
117 t1_.attach(*sslSocket_);
118 sslSocket_->write(this, buf_, sizeof(buf_));
119 sslSocket_->setReadCB(this);
120 memset(readbuf_, 'b', sizeof(readbuf_));
121 bytesRead_ = 0;
122 });
123 });
124 }
125
connectErr(const AsyncSocketException & ex)126 void connectErr(const AsyncSocketException& ex) noexcept override {
127 cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
128 sslSocket_.reset();
129 }
130
writeSuccess()131 void writeSuccess() noexcept override {
132 cerr << "client write success" << endl;
133 }
134
writeErr(size_t,const AsyncSocketException & ex)135 void writeErr(
136 size_t /* bytesWritten */,
137 const AsyncSocketException& ex) noexcept override {
138 cerr << "client writeError: " << ex.what() << endl;
139 }
140
getReadBuffer(void ** bufReturn,size_t * lenReturn)141 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
142 *bufReturn = readbuf_ + bytesRead_;
143 *lenReturn = sizeof(readbuf_) - bytesRead_;
144 }
readEOF()145 void readEOF() noexcept override { cerr << "client readEOF" << endl; }
146
readErr(const AsyncSocketException & ex)147 void readErr(const AsyncSocketException& ex) noexcept override {
148 cerr << "client readError: " << ex.what() << endl;
149 promise_.setException(ex);
150 }
151
readDataAvailable(size_t len)152 void readDataAvailable(size_t len) noexcept override {
153 EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
154 EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
155 cerr << "client read data: " << len << endl;
156 bytesRead_ += len;
157 if (len == sizeof(buf_)) {
158 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
159 sslSocket_->closeNow();
160 sslSocket_.reset();
161 promise_.setValue(true);
162 }
163 }
164 };
165
166 /**
167 * Test passing contexts between threads
168 */
TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
  // Start listening on a local port. The server-side read callback is given
  // a write callback; the client expects its payload to be sent back.
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  auto client = std::make_shared<AttachDetachClient>(server.getAddress());

  auto f = client->getFuture();
  client->connect();
  EXPECT_TRUE(std::move(f).within(std::chrono::seconds(3)).get());
}
184
185 class ConnectClient : public AsyncSocket::ConnectCallback {
186 public:
187 ConnectClient() = default;
188
getFuture()189 Future<bool> getFuture() { return promise_.getFuture(); }
190
connect(const folly::SocketAddress & addr)191 void connect(const folly::SocketAddress& addr) {
192 t1_.getEventBase()->runInEventBaseThread([&] {
193 socket_ = t1_.createSocket();
194 socket_->connect(this, addr);
195 });
196 }
197
connectSuccess()198 void connectSuccess() noexcept override {
199 socket_.reset();
200 promise_.setValue(true);
201 }
202
connectErr(const AsyncSocketException &)203 void connectErr(const AsyncSocketException& /* ex */) noexcept override {
204 socket_.reset();
205 promise_.setValue(false);
206 }
207
setCtx(std::shared_ptr<SSLContext> ctx)208 void setCtx(std::shared_ptr<SSLContext> ctx) { t1_.ctx_ = ctx; }
209
210 private:
211 EvbAndContext t1_;
212 // promise to fulfill when done with a value of true if connect succeeded
213 folly::Promise<bool> promise_;
214 std::shared_ptr<AsyncSSLSocket> socket_;
215 };
216
217 class NoopReadCallback : public ReadCallbackBase {
218 public:
NoopReadCallback()219 NoopReadCallback() : ReadCallbackBase(nullptr) { state = STATE_SUCCEEDED; }
220
getReadBuffer(void ** buf,size_t * lenReturn)221 void getReadBuffer(void** buf, size_t* lenReturn) override {
222 *buf = &buffer_;
223 *lenReturn = 1;
224 }
readDataAvailable(size_t)225 void readDataAvailable(size_t) noexcept override {}
226
227 uint8_t buffer_{0};
228 };
229
TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
  // Server: SSLContext built with SSLContext::TLSv1_2, test certificates
  // loaded, incoming data discarded.
  NoopReadCallback noopReader;
  HandshakeCallback handshake(&noopReader);
  SSLServerAcceptCallbackDelay acceptor(&handshake);
  auto serverCtx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  TestSSLServer server(&acceptor, serverCtx);
  server.loadTestCerts();

  // Client with the default EvbAndContext settings is expected to connect.
  auto client = std::make_unique<ConnectClient>();
  auto connected = client->getFuture();
  client->connect(server.getAddress());
  EXPECT_TRUE(std::move(connected).within(std::chrono::seconds(3)).get());
}
245
TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
  // Server: SSLContext built with SSLContext::TLSv1_2. Its handshake
  // callback is told to expect a handshake error.
  NoopReadCallback noopReader;
  HandshakeCallback handshake(&noopReader, HandshakeCallback::EXPECT_ERROR);
  SSLServerAcceptCallbackDelay acceptor(&handshake);
  auto serverCtx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  TestSSLServer server(&acceptor, serverCtx);
  server.loadTestCerts();

  // Client whose context disables TLS 1.2 is expected to fail to connect.
  auto client = std::make_unique<ConnectClient>();
  auto noTls12Ctx = std::make_shared<SSLContext>();
  noTls12Ctx->setOptions(SSL_OP_NO_TLSv1_2);
  client->setCtx(noTls12Ctx);
  auto connected = client->getFuture();
  client->connect(server.getAddress());
  EXPECT_FALSE(std::move(connected).within(std::chrono::seconds(3)).get());
}
265
266 } // namespace folly
267
main(int argc,char * argv[])268 int main(int argc, char* argv[]) {
269 folly::ssl::init();
270 #ifdef SIGPIPE
271 signal(SIGPIPE, SIG_IGN);
272 #endif
273 testing::InitGoogleTest(&argc, argv);
274 folly::init(&argc, &argv);
275 return RUN_ALL_TESTS();
276 OPENSSL_cleanup();
277 }
278