1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/ssl_client_socket.h"
6
7 #include <string>
8
9 #include "base/logging.h"
10 #include "net/socket/ssl_client_socket_impl.h"
11 #include "net/socket/stream_socket.h"
12 #include "net/ssl/ssl_client_session_cache.h"
13 #include "net/ssl/ssl_key_logger.h"
14
15 namespace net {
16
SSLClientSocket()17 SSLClientSocket::SSLClientSocket()
18 : signed_cert_timestamps_received_(false),
19 stapled_ocsp_response_received_(false) {}
20
21 // static
SetSSLKeyLogger(std::unique_ptr<SSLKeyLogger> logger)22 void SSLClientSocket::SetSSLKeyLogger(std::unique_ptr<SSLKeyLogger> logger) {
23 SSLClientSocketImpl::SetSSLKeyLogger(std::move(logger));
24 }
25
26 // static
SerializeNextProtos(const NextProtoVector & next_protos)27 std::vector<uint8_t> SSLClientSocket::SerializeNextProtos(
28 const NextProtoVector& next_protos) {
29 std::vector<uint8_t> wire_protos;
30 for (const NextProto next_proto : next_protos) {
31 const std::string proto = NextProtoToString(next_proto);
32 if (proto.size() > 255) {
33 LOG(WARNING) << "Ignoring overlong ALPN protocol: " << proto;
34 continue;
35 }
36 if (proto.size() == 0) {
37 LOG(WARNING) << "Ignoring empty ALPN protocol";
38 continue;
39 }
40 wire_protos.push_back(proto.size());
41 for (const char ch : proto) {
42 wire_protos.push_back(static_cast<uint8_t>(ch));
43 }
44 }
45
46 return wire_protos;
47 }
48
SSLClientContext(SSLConfigService * ssl_config_service,CertVerifier * cert_verifier,TransportSecurityState * transport_security_state,CTVerifier * cert_transparency_verifier,CTPolicyEnforcer * ct_policy_enforcer,SSLClientSessionCache * ssl_client_session_cache)49 SSLClientContext::SSLClientContext(
50 SSLConfigService* ssl_config_service,
51 CertVerifier* cert_verifier,
52 TransportSecurityState* transport_security_state,
53 CTVerifier* cert_transparency_verifier,
54 CTPolicyEnforcer* ct_policy_enforcer,
55 SSLClientSessionCache* ssl_client_session_cache)
56 : ssl_config_service_(ssl_config_service),
57 cert_verifier_(cert_verifier),
58 transport_security_state_(transport_security_state),
59 cert_transparency_verifier_(cert_transparency_verifier),
60 ct_policy_enforcer_(ct_policy_enforcer),
61 ssl_client_session_cache_(ssl_client_session_cache) {
62 CHECK(cert_verifier_);
63 CHECK(transport_security_state_);
64 CHECK(cert_transparency_verifier_);
65 CHECK(ct_policy_enforcer_);
66
67 if (ssl_config_service_) {
68 config_ = ssl_config_service_->GetSSLContextConfig();
69 ssl_config_service_->AddObserver(this);
70 }
71 CertDatabase::GetInstance()->AddObserver(this);
72 }
73
~SSLClientContext()74 SSLClientContext::~SSLClientContext() {
75 if (ssl_config_service_) {
76 ssl_config_service_->RemoveObserver(this);
77 }
78 CertDatabase::GetInstance()->RemoveObserver(this);
79 }
80
CreateSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)81 std::unique_ptr<SSLClientSocket> SSLClientContext::CreateSSLClientSocket(
82 std::unique_ptr<StreamSocket> stream_socket,
83 const HostPortPair& host_and_port,
84 const SSLConfig& ssl_config) {
85 return std::make_unique<SSLClientSocketImpl>(this, std::move(stream_socket),
86 host_and_port, ssl_config);
87 }
88
GetClientCertificate(const HostPortPair & server,scoped_refptr<X509Certificate> * client_cert,scoped_refptr<SSLPrivateKey> * private_key)89 bool SSLClientContext::GetClientCertificate(
90 const HostPortPair& server,
91 scoped_refptr<X509Certificate>* client_cert,
92 scoped_refptr<SSLPrivateKey>* private_key) {
93 return ssl_client_auth_cache_.Lookup(server, client_cert, private_key);
94 }
95
SetClientCertificate(const HostPortPair & server,scoped_refptr<X509Certificate> client_cert,scoped_refptr<SSLPrivateKey> private_key)96 void SSLClientContext::SetClientCertificate(
97 const HostPortPair& server,
98 scoped_refptr<X509Certificate> client_cert,
99 scoped_refptr<SSLPrivateKey> private_key) {
100 ssl_client_auth_cache_.Add(server, std::move(client_cert),
101 std::move(private_key));
102
103 if (ssl_client_session_cache_) {
104 // Session resumption bypasses client certificate negotiation, so flush all
105 // associated sessions when preferences change.
106 ssl_client_session_cache_->FlushForServer(server);
107 }
108 NotifySSLConfigForServerChanged(server);
109 }
110
ClearClientCertificate(const HostPortPair & server)111 bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) {
112 if (!ssl_client_auth_cache_.Remove(server)) {
113 return false;
114 }
115
116 if (ssl_client_session_cache_) {
117 // Session resumption bypasses client certificate negotiation, so flush all
118 // associated sessions when preferences change.
119 ssl_client_session_cache_->FlushForServer(server);
120 }
121 NotifySSLConfigForServerChanged(server);
122 return true;
123 }
124
AddObserver(Observer * observer)125 void SSLClientContext::AddObserver(Observer* observer) {
126 observers_.AddObserver(observer);
127 }
128
RemoveObserver(Observer * observer)129 void SSLClientContext::RemoveObserver(Observer* observer) {
130 observers_.RemoveObserver(observer);
131 }
132
OnSSLContextConfigChanged()133 void SSLClientContext::OnSSLContextConfigChanged() {
134 // TODO(davidben): Should we flush |ssl_client_session_cache_| here? We flush
135 // the socket pools, but not the session cache. While BoringSSL-based servers
136 // never change version or cipher negotiation based on client-offered
137 // sessions, other servers do.
138 config_ = ssl_config_service_->GetSSLContextConfig();
139 NotifySSLConfigChanged(false /* not a cert database change */);
140 }
141
OnCertDBChanged()142 void SSLClientContext::OnCertDBChanged() {
143 // Both the trust store and client certificate store may have changed.
144 ssl_client_auth_cache_.Clear();
145 if (ssl_client_session_cache_) {
146 ssl_client_session_cache_->Flush();
147 }
148 NotifySSLConfigChanged(true /* cert database change */);
149 }
150
NotifySSLConfigChanged(bool is_cert_database_change)151 void SSLClientContext::NotifySSLConfigChanged(bool is_cert_database_change) {
152 for (Observer& observer : observers_) {
153 observer.OnSSLConfigChanged(is_cert_database_change);
154 }
155 }
156
NotifySSLConfigForServerChanged(const HostPortPair & server)157 void SSLClientContext::NotifySSLConfigForServerChanged(
158 const HostPortPair& server) {
159 for (Observer& observer : observers_) {
160 observer.OnSSLConfigForServerChanged(server);
161 }
162 }
163
164 } // namespace net
165