1 /* 2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef RTC_BASE_OPENSSLADAPTER_H_ 12 #define RTC_BASE_OPENSSLADAPTER_H_ 13 14 #include <map> 15 #include <string> 16 #include "rtc_base/buffer.h" 17 #include "rtc_base/messagehandler.h" 18 #include "rtc_base/messagequeue.h" 19 #include "rtc_base/opensslidentity.h" 20 #include "rtc_base/ssladapter.h" 21 22 typedef struct ssl_st SSL; 23 typedef struct ssl_ctx_st SSL_CTX; 24 typedef struct x509_store_ctx_st X509_STORE_CTX; 25 typedef struct ssl_session_st SSL_SESSION; 26 27 namespace rtc { 28 29 class OpenSSLAdapterFactory; 30 31 class OpenSSLAdapter : public SSLAdapter, public MessageHandler { 32 public: 33 static bool InitializeSSL(VerificationCallback callback); 34 static bool InitializeSSLThread(); 35 static bool CleanupSSL(); 36 37 explicit OpenSSLAdapter(AsyncSocket* socket, 38 OpenSSLAdapterFactory* factory = nullptr); 39 ~OpenSSLAdapter() override; 40 41 void SetIgnoreBadCert(bool ignore) override; 42 void SetAlpnProtocols(const std::vector<std::string>& protos) override; 43 void SetEllipticCurves(const std::vector<std::string>& curves) override; 44 45 void SetMode(SSLMode mode) override; 46 void SetIdentity(SSLIdentity* identity) override; 47 void SetRole(SSLRole role) override; 48 AsyncSocket* Accept(SocketAddress* paddr) override; 49 int StartSSL(const char* hostname, bool restartable) override; 50 int Send(const void* pv, size_t cb) override; 51 int SendTo(const void* pv, size_t cb, const SocketAddress& addr) override; 52 int Recv(void* pv, size_t cb, int64_t* timestamp) override; 53 int RecvFrom(void* pv, 54 size_t cb, 55 SocketAddress* paddr, 56 int64_t* timestamp) override; 57 int Close() override; 58 59 // Note that the socket returns ST_CONNECTING while SSL is being negotiated. 60 ConnState GetState() const override; 61 bool IsResumedSession() override; 62 63 // Creates a new SSL_CTX object, configured for client-to-server usage 64 // with SSLMode |mode|, and if |enable_cache| is true, with support for 65 // storing successful sessions so that they can be later resumed. 66 // OpenSSLAdapterFactory will call this method to create its own internal 67 // SSL_CTX, and OpenSSLAdapter will also call this when used without a 68 // factory. 69 static SSL_CTX* CreateContext(SSLMode mode, bool enable_cache); 70 71 protected: 72 void OnConnectEvent(AsyncSocket* socket) override; 73 void OnReadEvent(AsyncSocket* socket) override; 74 void OnWriteEvent(AsyncSocket* socket) override; 75 void OnCloseEvent(AsyncSocket* socket, int err) override; 76 77 private: 78 enum SSLState { 79 SSL_NONE, SSL_WAIT, SSL_CONNECTING, SSL_CONNECTED, SSL_ERROR 80 }; 81 82 enum { MSG_TIMEOUT }; 83 84 int BeginSSL(); 85 int ContinueSSL(); 86 void Error(const char* context, int err, bool signal = true); 87 void Cleanup(); 88 89 // Return value and arguments have the same meanings as for Send; |error| is 90 // an output parameter filled with the result of SSL_get_error. 91 int DoSslWrite(const void* pv, size_t cb, int* error); 92 93 void OnMessage(Message* msg) override; 94 95 static bool VerifyServerName(SSL* ssl, const char* host, 96 bool ignore_bad_cert); 97 bool SSLPostConnectionCheck(SSL* ssl, const char* host); 98 #if !defined(NDEBUG) 99 // In debug builds, logs info about the state of the SSL connection. 100 static void SSLInfoCallback(const SSL* ssl, int where, int ret); 101 #endif 102 static int SSLVerifyCallback(int ok, X509_STORE_CTX* store); 103 static VerificationCallback custom_verify_callback_; 104 friend class OpenSSLStreamAdapter; // for custom_verify_callback_; 105 106 // If the SSL_CTX was created with |enable_cache| set to true, this callback 107 // will be called when a SSL session has been successfully established, 108 // to allow its SSL_SESSION* to be cached for later resumption. 109 static int NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session); 110 111 static bool ConfigureTrustedRootCertificates(SSL_CTX* ctx); 112 113 // Parent object that maintains shared state. 114 // Can be null if state sharing is not needed. 115 OpenSSLAdapterFactory* factory_; 116 117 SSLState state_; 118 std::unique_ptr<OpenSSLIdentity> identity_; 119 SSLRole role_; 120 bool ssl_read_needs_write_; 121 bool ssl_write_needs_read_; 122 // If true, socket will retain SSL configuration after Close. 123 // TODO(juberti): Remove this unused flag. 124 bool restartable_; 125 126 // This buffer is used if SSL_write fails with SSL_ERROR_WANT_WRITE, which 127 // means we need to keep retrying with *the same exact data* until it 128 // succeeds. Afterwards it will be cleared. 129 Buffer pending_data_; 130 131 SSL* ssl_; 132 SSL_CTX* ssl_ctx_; 133 std::string ssl_host_name_; 134 // Do DTLS or not 135 SSLMode ssl_mode_; 136 // If true, the server certificate need not match the configured hostname. 137 bool ignore_bad_cert_; 138 // List of protocols to be used in the TLS ALPN extension. 139 std::vector<std::string> alpn_protocols_; 140 // List of elliptic curves to be used in the TLS elliptic curves extension. 141 std::vector<std::string> elliptic_curves_; 142 143 bool custom_verification_succeeded_; 144 }; 145 146 std::string TransformAlpnProtocols(const std::vector<std::string>& protos); 147 148 ///////////////////////////////////////////////////////////////////////////// 149 class OpenSSLAdapterFactory : public SSLAdapterFactory { 150 public: 151 OpenSSLAdapterFactory(); 152 ~OpenSSLAdapterFactory() override; 153 154 void SetMode(SSLMode mode) override; 155 OpenSSLAdapter* CreateAdapter(AsyncSocket* socket) override; 156 157 static OpenSSLAdapterFactory* Create(); 158 159 private: ssl_ctx()160 SSL_CTX* ssl_ctx() { return ssl_ctx_; } 161 // Looks up a session by hostname. The returned SSL_SESSION is not up_refed. 162 SSL_SESSION* LookupSession(const std::string& hostname); 163 // Adds a session to the cache, and up_refs it. Any existing session with the 164 // same hostname is replaced. 165 void AddSession(const std::string& hostname, SSL_SESSION* session); 166 friend class OpenSSLAdapter; 167 168 SSLMode ssl_mode_; 169 // Holds the shared SSL_CTX for all created adapters. 170 SSL_CTX* ssl_ctx_; 171 // Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs, 172 // which are cleaned up when the factory is destroyed. 173 // TODO(juberti): Add LRU eviction to keep the cache from growing forever. 174 std::map<std::string, SSL_SESSION*> sessions_; 175 }; 176 177 } // namespace rtc 178 179 #endif // RTC_BASE_OPENSSLADAPTER_H_ 180