1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef RTC_BASE_OPENSSLADAPTER_H_
12 #define RTC_BASE_OPENSSLADAPTER_H_
13 
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "rtc_base/buffer.h"
#include "rtc_base/messagehandler.h"
#include "rtc_base/messagequeue.h"
#include "rtc_base/opensslidentity.h"
#include "rtc_base/ssladapter.h"
21 
// Forward declarations of the OpenSSL structs referenced below. This header
// only ever uses pointers to them, so the complete types are not required.
using SSL = struct ssl_st;
using SSL_CTX = struct ssl_ctx_st;
using X509_STORE_CTX = struct x509_store_ctx_st;
using SSL_SESSION = struct ssl_session_st;
26 
27 namespace rtc {
28 
// Defined later in this header. OpenSSLAdapter only accepts and stores a
// pointer to it, so a forward declaration is sufficient here.
class OpenSSLAdapterFactory;
30 
31 class OpenSSLAdapter : public SSLAdapter, public MessageHandler {
32  public:
33   static bool InitializeSSL(VerificationCallback callback);
34   static bool InitializeSSLThread();
35   static bool CleanupSSL();
36 
37   explicit OpenSSLAdapter(AsyncSocket* socket,
38                           OpenSSLAdapterFactory* factory = nullptr);
39   ~OpenSSLAdapter() override;
40 
41   void SetIgnoreBadCert(bool ignore) override;
42   void SetAlpnProtocols(const std::vector<std::string>& protos) override;
43   void SetEllipticCurves(const std::vector<std::string>& curves) override;
44 
45   void SetMode(SSLMode mode) override;
46   void SetIdentity(SSLIdentity* identity) override;
47   void SetRole(SSLRole role) override;
48   AsyncSocket* Accept(SocketAddress* paddr) override;
49   int StartSSL(const char* hostname, bool restartable) override;
50   int Send(const void* pv, size_t cb) override;
51   int SendTo(const void* pv, size_t cb, const SocketAddress& addr) override;
52   int Recv(void* pv, size_t cb, int64_t* timestamp) override;
53   int RecvFrom(void* pv,
54                size_t cb,
55                SocketAddress* paddr,
56                int64_t* timestamp) override;
57   int Close() override;
58 
59   // Note that the socket returns ST_CONNECTING while SSL is being negotiated.
60   ConnState GetState() const override;
61   bool IsResumedSession() override;
62 
63   // Creates a new SSL_CTX object, configured for client-to-server usage
64   // with SSLMode |mode|, and if |enable_cache| is true, with support for
65   // storing successful sessions so that they can be later resumed.
66   // OpenSSLAdapterFactory will call this method to create its own internal
67   // SSL_CTX, and OpenSSLAdapter will also call this when used without a
68   // factory.
69   static SSL_CTX* CreateContext(SSLMode mode, bool enable_cache);
70 
71  protected:
72   void OnConnectEvent(AsyncSocket* socket) override;
73   void OnReadEvent(AsyncSocket* socket) override;
74   void OnWriteEvent(AsyncSocket* socket) override;
75   void OnCloseEvent(AsyncSocket* socket, int err) override;
76 
77  private:
78   enum SSLState {
79     SSL_NONE, SSL_WAIT, SSL_CONNECTING, SSL_CONNECTED, SSL_ERROR
80   };
81 
82   enum { MSG_TIMEOUT };
83 
84   int BeginSSL();
85   int ContinueSSL();
86   void Error(const char* context, int err, bool signal = true);
87   void Cleanup();
88 
89   // Return value and arguments have the same meanings as for Send; |error| is
90   // an output parameter filled with the result of SSL_get_error.
91   int DoSslWrite(const void* pv, size_t cb, int* error);
92 
93   void OnMessage(Message* msg) override;
94 
95   static bool VerifyServerName(SSL* ssl, const char* host,
96                                bool ignore_bad_cert);
97   bool SSLPostConnectionCheck(SSL* ssl, const char* host);
98 #if !defined(NDEBUG)
99   // In debug builds, logs info about the state of the SSL connection.
100   static void SSLInfoCallback(const SSL* ssl, int where, int ret);
101 #endif
102   static int SSLVerifyCallback(int ok, X509_STORE_CTX* store);
103   static VerificationCallback custom_verify_callback_;
104   friend class OpenSSLStreamAdapter;  // for custom_verify_callback_;
105 
106   // If the SSL_CTX was created with |enable_cache| set to true, this callback
107   // will be called when a SSL session has been successfully established,
108   // to allow its SSL_SESSION* to be cached for later resumption.
109   static int NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session);
110 
111   static bool ConfigureTrustedRootCertificates(SSL_CTX* ctx);
112 
113   // Parent object that maintains shared state.
114   // Can be null if state sharing is not needed.
115   OpenSSLAdapterFactory* factory_;
116 
117   SSLState state_;
118   std::unique_ptr<OpenSSLIdentity> identity_;
119   SSLRole role_;
120   bool ssl_read_needs_write_;
121   bool ssl_write_needs_read_;
122   // If true, socket will retain SSL configuration after Close.
123   // TODO(juberti): Remove this unused flag.
124   bool restartable_;
125 
126   // This buffer is used if SSL_write fails with SSL_ERROR_WANT_WRITE, which
127   // means we need to keep retrying with *the same exact data* until it
128   // succeeds. Afterwards it will be cleared.
129   Buffer pending_data_;
130 
131   SSL* ssl_;
132   SSL_CTX* ssl_ctx_;
133   std::string ssl_host_name_;
134   // Do DTLS or not
135   SSLMode ssl_mode_;
136   // If true, the server certificate need not match the configured hostname.
137   bool ignore_bad_cert_;
138   // List of protocols to be used in the TLS ALPN extension.
139   std::vector<std::string> alpn_protocols_;
140   // List of elliptic curves to be used in the TLS elliptic curves extension.
141   std::vector<std::string> elliptic_curves_;
142 
143   bool custom_verification_succeeded_;
144 };
145 
// Converts the ALPN protocol list |protos| into a single string for use with
// the TLS ALPN extension.
// NOTE(review): presumably the length-prefixed wire format that OpenSSL's
// SSL_set_alpn_protos expects; confirm against the implementation.
std::string TransformAlpnProtocols(const std::vector<std::string>& protos);
147 
148 /////////////////////////////////////////////////////////////////////////////
149 class OpenSSLAdapterFactory : public SSLAdapterFactory {
150  public:
151   OpenSSLAdapterFactory();
152   ~OpenSSLAdapterFactory() override;
153 
154   void SetMode(SSLMode mode) override;
155   OpenSSLAdapter* CreateAdapter(AsyncSocket* socket) override;
156 
157   static OpenSSLAdapterFactory* Create();
158 
159  private:
ssl_ctx()160   SSL_CTX* ssl_ctx() { return ssl_ctx_; }
161   // Looks up a session by hostname. The returned SSL_SESSION is not up_refed.
162   SSL_SESSION* LookupSession(const std::string& hostname);
163   // Adds a session to the cache, and up_refs it. Any existing session with the
164   // same hostname is replaced.
165   void AddSession(const std::string& hostname, SSL_SESSION* session);
166   friend class OpenSSLAdapter;
167 
168   SSLMode ssl_mode_;
169   // Holds the shared SSL_CTX for all created adapters.
170   SSL_CTX* ssl_ctx_;
171   // Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs,
172   // which are cleaned up when the factory is destroyed.
173   // TODO(juberti): Add LRU eviction to keep the cache from growing forever.
174   std::map<std::string, SSL_SESSION*> sessions_;
175 };
176 
177 }  // namespace rtc
178 
179 #endif  // RTC_BASE_OPENSSLADAPTER_H_
180