1 /*
2  *  Copyright 2014 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <string>
12 
13 #include "webrtc/base/gunit.h"
14 #include "webrtc/base/ipaddress.h"
15 #include "webrtc/base/socketstream.h"
16 #include "webrtc/base/ssladapter.h"
17 #include "webrtc/base/sslstreamadapter.h"
18 #include "webrtc/base/stream.h"
19 #include "webrtc/base/virtualsocketserver.h"
20 
// Wait budget handed to the EXPECT_*_WAIT macros and used as the default
// handshake wait (assumed to be milliseconds, per the *_WAIT macros in
// gunit.h — confirm there if the unit matters).
static const int kTimeout = 5 * 1000;
22 
23 static rtc::AsyncSocket* CreateSocket(const rtc::SSLMode& ssl_mode) {
24   rtc::SocketAddress address(rtc::IPAddress(INADDR_ANY), 0);
25 
26   rtc::AsyncSocket* socket = rtc::Thread::Current()->
27       socketserver()->CreateAsyncSocket(
28       address.family(), (ssl_mode == rtc::SSL_MODE_DTLS) ?
29       SOCK_DGRAM : SOCK_STREAM);
30   socket->Bind(address);
31 
32   return socket;
33 }
34 
35 static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) {
36   return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS";
37 }
38 
39 class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
40  public:
41   explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode)
42       : ssl_mode_(ssl_mode) {
43     rtc::AsyncSocket* socket = CreateSocket(ssl_mode_);
44 
45     ssl_adapter_.reset(rtc::SSLAdapter::Create(socket));
46 
47     ssl_adapter_->SetMode(ssl_mode_);
48 
49     // Ignore any certificate errors for the purpose of testing.
50     // Note: We do this only because we don't have a real certificate.
51     // NEVER USE THIS IN PRODUCTION CODE!
52     ssl_adapter_->set_ignore_bad_cert(true);
53 
54     ssl_adapter_->SignalReadEvent.connect(this,
55         &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent);
56     ssl_adapter_->SignalCloseEvent.connect(this,
57         &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent);
58   }
59 
60   rtc::SocketAddress GetAddress() const {
61     return ssl_adapter_->GetLocalAddress();
62   }
63 
64   rtc::AsyncSocket::ConnState GetState() const {
65     return ssl_adapter_->GetState();
66   }
67 
68   const std::string& GetReceivedData() const {
69     return data_;
70   }
71 
72   int Connect(const std::string& hostname, const rtc::SocketAddress& address) {
73     LOG(LS_INFO) << "Initiating connection with " << address;
74 
75     int rv = ssl_adapter_->Connect(address);
76 
77     if (rv == 0) {
78       LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_)
79           << " handshake with " << hostname;
80 
81       if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) {
82         return -1;
83       }
84     }
85 
86     return rv;
87   }
88 
89   int Close() {
90     return ssl_adapter_->Close();
91   }
92 
93   int Send(const std::string& message) {
94     LOG(LS_INFO) << "Client sending '" << message << "'";
95 
96     return ssl_adapter_->Send(message.data(), message.length());
97   }
98 
99   void OnSSLAdapterReadEvent(rtc::AsyncSocket* socket) {
100     char buffer[4096] = "";
101 
102     // Read data received from the server and store it in our internal buffer.
103     int read = socket->Recv(buffer, sizeof(buffer) - 1);
104     if (read != -1) {
105       buffer[read] = '\0';
106 
107       LOG(LS_INFO) << "Client received '" << buffer << "'";
108 
109       data_ += buffer;
110     }
111   }
112 
113   void OnSSLAdapterCloseEvent(rtc::AsyncSocket* socket, int error) {
114     // OpenSSLAdapter signals handshake failure with a close event, but without
115     // closing the socket! Let's close the socket here. This way GetState() can
116     // return CS_CLOSED after failure.
117     if (socket->GetState() != rtc::AsyncSocket::CS_CLOSED) {
118       socket->Close();
119     }
120   }
121 
122  private:
123   const rtc::SSLMode ssl_mode_;
124 
125   rtc::scoped_ptr<rtc::SSLAdapter> ssl_adapter_;
126 
127   std::string data_;
128 };
129 
130 class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
131  public:
132   explicit SSLAdapterTestDummyServer(const rtc::SSLMode& ssl_mode)
133       : ssl_mode_(ssl_mode) {
134     // Generate a key pair and a certificate for this host.
135     ssl_identity_.reset(rtc::SSLIdentity::Generate(GetHostname()));
136 
137     server_socket_.reset(CreateSocket(ssl_mode_));
138 
139     if (ssl_mode_ == rtc::SSL_MODE_TLS) {
140       server_socket_->SignalReadEvent.connect(this,
141           &SSLAdapterTestDummyServer::OnServerSocketReadEvent);
142 
143       server_socket_->Listen(1);
144     }
145 
146     LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP")
147         << " server listening on " << server_socket_->GetLocalAddress();
148   }
149 
150   rtc::SocketAddress GetAddress() const {
151     return server_socket_->GetLocalAddress();
152   }
153 
154   std::string GetHostname() const {
155     // Since we don't have a real certificate anyway, the value here doesn't
156     // really matter.
157     return "example.com";
158   }
159 
160   const std::string& GetReceivedData() const {
161     return data_;
162   }
163 
164   int Send(const std::string& message) {
165     if (ssl_stream_adapter_ == NULL
166         || ssl_stream_adapter_->GetState() != rtc::SS_OPEN) {
167       // No connection yet.
168       return -1;
169     }
170 
171     LOG(LS_INFO) << "Server sending '" << message << "'";
172 
173     size_t written;
174     int error;
175 
176     rtc::StreamResult r = ssl_stream_adapter_->Write(message.data(),
177         message.length(), &written, &error);
178     if (r == rtc::SR_SUCCESS) {
179       return written;
180     } else {
181       return -1;
182     }
183   }
184 
185   void AcceptConnection(const rtc::SocketAddress& address) {
186     // Only a single connection is supported.
187     ASSERT_TRUE(ssl_stream_adapter_ == NULL);
188 
189     // This is only for DTLS.
190     ASSERT_EQ(rtc::SSL_MODE_DTLS, ssl_mode_);
191 
192     // Transfer ownership of the socket to the SSLStreamAdapter object.
193     rtc::AsyncSocket* socket = server_socket_.release();
194 
195     socket->Connect(address);
196 
197     DoHandshake(socket);
198   }
199 
200   void OnServerSocketReadEvent(rtc::AsyncSocket* socket) {
201     // Only a single connection is supported.
202     ASSERT_TRUE(ssl_stream_adapter_ == NULL);
203 
204     DoHandshake(server_socket_->Accept(NULL));
205   }
206 
207   void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) {
208     if (sig & rtc::SE_READ) {
209       char buffer[4096] = "";
210 
211       size_t read;
212       int error;
213 
214       // Read data received from the client and store it in our internal
215       // buffer.
216       rtc::StreamResult r = stream->Read(buffer,
217           sizeof(buffer) - 1, &read, &error);
218       if (r == rtc::SR_SUCCESS) {
219         buffer[read] = '\0';
220 
221         LOG(LS_INFO) << "Server received '" << buffer << "'";
222 
223         data_ += buffer;
224       }
225     }
226   }
227 
228  private:
229   void DoHandshake(rtc::AsyncSocket* socket) {
230     rtc::SocketStream* stream = new rtc::SocketStream(socket);
231 
232     ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream));
233 
234     ssl_stream_adapter_->SetMode(ssl_mode_);
235     ssl_stream_adapter_->SetServerRole();
236 
237     // SSLStreamAdapter is normally used for peer-to-peer communication, but
238     // here we're testing communication between a client and a server
239     // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where
240     // clients are not required to provide a certificate during handshake.
241     // Accordingly, we must disable client authentication here.
242     ssl_stream_adapter_->set_client_auth_enabled(false);
243 
244     ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference());
245 
246     // Set a bogus peer certificate digest.
247     unsigned char digest[20];
248     size_t digest_len = sizeof(digest);
249     ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
250         digest_len);
251 
252     ssl_stream_adapter_->StartSSLWithPeer();
253 
254     ssl_stream_adapter_->SignalEvent.connect(this,
255         &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent);
256   }
257 
258   const rtc::SSLMode ssl_mode_;
259 
260   rtc::scoped_ptr<rtc::AsyncSocket> server_socket_;
261   rtc::scoped_ptr<rtc::SSLStreamAdapter> ssl_stream_adapter_;
262 
263   rtc::scoped_ptr<rtc::SSLIdentity> ssl_identity_;
264 
265   std::string data_;
266 };
267 
268 class SSLAdapterTestBase : public testing::Test,
269                            public sigslot::has_slots<> {
270  public:
271   explicit SSLAdapterTestBase(const rtc::SSLMode& ssl_mode)
272       : ssl_mode_(ssl_mode),
273         ss_scope_(new rtc::VirtualSocketServer(NULL)),
274         server_(new SSLAdapterTestDummyServer(ssl_mode_)),
275         client_(new SSLAdapterTestDummyClient(ssl_mode_)),
276         handshake_wait_(kTimeout) {
277   }
278 
279   void SetHandshakeWait(int wait) {
280     handshake_wait_ = wait;
281   }
282 
283   void TestHandshake(bool expect_success) {
284     int rv;
285 
286     // The initial state is CS_CLOSED
287     ASSERT_EQ(rtc::AsyncSocket::CS_CLOSED, client_->GetState());
288 
289     rv = client_->Connect(server_->GetHostname(), server_->GetAddress());
290     ASSERT_EQ(0, rv);
291 
292     // Now the state should be CS_CONNECTING
293     ASSERT_EQ(rtc::AsyncSocket::CS_CONNECTING, client_->GetState());
294 
295     if (ssl_mode_ == rtc::SSL_MODE_DTLS) {
296       // For DTLS, call AcceptConnection() with the client's address.
297       server_->AcceptConnection(client_->GetAddress());
298     }
299 
300     if (expect_success) {
301       // If expecting success, the client should end up in the CS_CONNECTED
302       // state after handshake.
303       EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CONNECTED, client_->GetState(),
304           handshake_wait_);
305 
306       LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake complete.";
307 
308     } else {
309       // On handshake failure the client should end up in the CS_CLOSED state.
310       EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CLOSED, client_->GetState(),
311           handshake_wait_);
312 
313       LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake failed.";
314     }
315   }
316 
317   void TestTransfer(const std::string& message) {
318     int rv;
319 
320     rv = client_->Send(message);
321     ASSERT_EQ(static_cast<int>(message.length()), rv);
322 
323     // The server should have received the client's message.
324     EXPECT_EQ_WAIT(message, server_->GetReceivedData(), kTimeout);
325 
326     rv = server_->Send(message);
327     ASSERT_EQ(static_cast<int>(message.length()), rv);
328 
329     // The client should have received the server's message.
330     EXPECT_EQ_WAIT(message, client_->GetReceivedData(), kTimeout);
331 
332     LOG(LS_INFO) << "Transfer complete.";
333   }
334 
335  private:
336   const rtc::SSLMode ssl_mode_;
337 
338   const rtc::SocketServerScope ss_scope_;
339 
340   rtc::scoped_ptr<SSLAdapterTestDummyServer> server_;
341   rtc::scoped_ptr<SSLAdapterTestDummyClient> client_;
342 
343   int handshake_wait_;
344 };
345 
346 class SSLAdapterTestTLS : public SSLAdapterTestBase {
347  public:
348   SSLAdapterTestTLS() : SSLAdapterTestBase(rtc::SSL_MODE_TLS) {}
349 };
350 
351 class SSLAdapterTestDTLS : public SSLAdapterTestBase {
352  public:
353   SSLAdapterTestDTLS() : SSLAdapterTestBase(rtc::SSL_MODE_DTLS) {}
354 };
355 
#if SSL_USE_OPENSSL

// Payload sent in both directions by the transfer tests.
static const char kTransferMessage[] = "Hello, world!";

// ---- Basic tests: TLS ----

// The TLS handshake completes and the client reaches CS_CONNECTED.
TEST_F(SSLAdapterTestTLS, TestTLSConnect) {
  TestHandshake(true);
}

// After a TLS handshake, data makes a round trip client -> server -> client.
TEST_F(SSLAdapterTestTLS, TestTLSTransfer) {
  TestHandshake(true);
  TestTransfer(kTransferMessage);
}

// ---- Basic tests: DTLS ----

// The DTLS handshake completes and the client reaches CS_CONNECTED.
TEST_F(SSLAdapterTestDTLS, TestDTLSConnect) {
  TestHandshake(true);
}

// After a DTLS handshake, data makes a round trip client -> server -> client.
TEST_F(SSLAdapterTestDTLS, TestDTLSTransfer) {
  TestHandshake(true);
  TestTransfer(kTransferMessage);
}

#endif  // SSL_USE_OPENSSL
385 
386