1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "cast/sender/channel/sender_socket_factory.h"
6 
7 #include "cast/common/channel/cast_socket.h"
8 #include "cast/common/channel/proto/cast_channel.pb.h"
9 #include "cast/sender/channel/message_util.h"
10 #include "platform/base/tls_connect_options.h"
11 #include "util/crypto/certificate_utils.h"
12 #include "util/logging.h"
13 
14 using ::cast::channel::CastMessage;
15 
16 namespace openscreen {
17 namespace cast {
18 
operator <(const std::unique_ptr<SenderSocketFactory::PendingAuth> & a,int b)19 bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
20                int b) {
21   return a && a->socket->socket_id() < b;
22 }
23 
operator <(int a,const std::unique_ptr<SenderSocketFactory::PendingAuth> & b)24 bool operator<(int a,
25                const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
26   return b && a < b->socket->socket_id();
27 }
28 
SenderSocketFactory(Client * client,TaskRunner * task_runner)29 SenderSocketFactory::SenderSocketFactory(Client* client,
30                                          TaskRunner* task_runner)
31     : client_(client), task_runner_(task_runner) {
32   OSP_DCHECK(client);
33   OSP_DCHECK(task_runner);
34 }
35 
~SenderSocketFactory()36 SenderSocketFactory::~SenderSocketFactory() {
37   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
38 }
39 
set_factory(TlsConnectionFactory * factory)40 void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
41   OSP_DCHECK(factory);
42   factory_ = factory;
43 }
44 
Connect(const IPEndpoint & endpoint,DeviceMediaPolicy media_policy,CastSocket::Client * client)45 void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
46                                   DeviceMediaPolicy media_policy,
47                                   CastSocket::Client* client) {
48   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
49   OSP_DCHECK(client);
50   auto it = FindPendingConnection(endpoint);
51   if (it == pending_connections_.end()) {
52     pending_connections_.emplace_back(
53         PendingConnection{endpoint, media_policy, client});
54     factory_->Connect(endpoint, TlsConnectOptions{true});
55   }
56 }
57 
OnAccepted(TlsConnectionFactory * factory,std::vector<uint8_t> der_x509_peer_cert,std::unique_ptr<TlsConnection> connection)58 void SenderSocketFactory::OnAccepted(
59     TlsConnectionFactory* factory,
60     std::vector<uint8_t> der_x509_peer_cert,
61     std::unique_ptr<TlsConnection> connection) {
62   OSP_NOTREACHED() << "This factory is connect-only.";
63 }
64 
OnConnected(TlsConnectionFactory * factory,std::vector<uint8_t> der_x509_peer_cert,std::unique_ptr<TlsConnection> connection)65 void SenderSocketFactory::OnConnected(
66     TlsConnectionFactory* factory,
67     std::vector<uint8_t> der_x509_peer_cert,
68     std::unique_ptr<TlsConnection> connection) {
69   const IPEndpoint& endpoint = connection->GetRemoteEndpoint();
70   auto it = FindPendingConnection(endpoint);
71   if (it == pending_connections_.end()) {
72     OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
73                    << endpoint;
74     return;
75   }
76   DeviceMediaPolicy media_policy = it->media_policy;
77   CastSocket::Client* client = it->client;
78   pending_connections_.erase(it);
79 
80   ErrorOr<bssl::UniquePtr<X509>> peer_cert =
81       ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size());
82   if (!peer_cert) {
83     client_->OnError(this, endpoint, peer_cert.error());
84     return;
85   }
86 
87   auto socket =
88       MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this);
89   pending_auth_.emplace_back(
90       new PendingAuth{endpoint, media_policy, std::move(socket), client,
91                       AuthContext::Create(), std::move(peer_cert.value())});
92   PendingAuth& pending = *pending_auth_.back();
93 
94   CastMessage auth_challenge = CreateAuthChallengeMessage(pending.auth_context);
95   Error error = pending.socket->SendMessage(auth_challenge);
96   if (!error.ok()) {
97     pending_auth_.pop_back();
98     client_->OnError(this, endpoint, error);
99   }
100 }
101 
OnConnectionFailed(TlsConnectionFactory * factory,const IPEndpoint & remote_address)102 void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
103                                              const IPEndpoint& remote_address) {
104   auto it = FindPendingConnection(remote_address);
105   if (it == pending_connections_.end()) {
106     OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
107               << remote_address;
108     return;
109   }
110   pending_connections_.erase(it);
111   client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
112 }
113 
OnError(TlsConnectionFactory * factory,Error error)114 void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
115   std::vector<PendingConnection> connections;
116   pending_connections_.swap(connections);
117   for (const PendingConnection& pending : connections) {
118     client_->OnError(this, pending.endpoint, error);
119   }
120 }
121 
122 std::vector<SenderSocketFactory::PendingConnection>::iterator
FindPendingConnection(const IPEndpoint & endpoint)123 SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
124   return std::find_if(pending_connections_.begin(), pending_connections_.end(),
125                       [&endpoint](const PendingConnection& pending) {
126                         return pending.endpoint == endpoint;
127                       });
128 }
129 
OnError(CastSocket * socket,Error error)130 void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
131   auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
132                          [id = socket->socket_id()](
133                              const std::unique_ptr<PendingAuth>& pending_auth) {
134                            return pending_auth->socket->socket_id() == id;
135                          });
136   if (it == pending_auth_.end()) {
137     OSP_DLOG_ERROR << "Got error for unknown pending socket";
138     return;
139   }
140   IPEndpoint endpoint = (*it)->endpoint;
141   pending_auth_.erase(it);
142   client_->OnError(this, endpoint, error);
143 }
144 
OnMessage(CastSocket * socket,CastMessage message)145 void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
146   auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
147                          [id = socket->socket_id()](
148                              const std::unique_ptr<PendingAuth>& pending_auth) {
149                            return pending_auth->socket->socket_id() == id;
150                          });
151   if (it == pending_auth_.end()) {
152     OSP_DLOG_ERROR << "Got message for unknown pending socket";
153     return;
154   }
155 
156   std::unique_ptr<PendingAuth> pending = std::move(*it);
157   pending_auth_.erase(it);
158   if (!IsAuthMessage(message)) {
159     client_->OnError(this, pending->endpoint,
160                      Error::Code::kCastV2AuthenticationError);
161     return;
162   }
163 
164   ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
165       message, pending->peer_cert.get(), pending->auth_context);
166   if (policy_or_error.is_error()) {
167     OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint
168                   << " with error: " << policy_or_error.error();
169     client_->OnError(this, pending->endpoint, policy_or_error.error());
170     return;
171   }
172 
173   if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
174       pending->media_policy == DeviceMediaPolicy::kIncludesVideo) {
175     client_->OnError(this, pending->endpoint,
176                      Error::Code::kCastV2ChannelPolicyMismatch);
177     return;
178   }
179   pending->socket->set_audio_only(policy_or_error.value() ==
180                                   CastDeviceCertPolicy::kAudioOnly);
181 
182   pending->socket->SetClient(pending->client);
183   client_->OnConnected(this, pending->endpoint,
184                        std::unique_ptr<CastSocket>(pending->socket.release()));
185 }
186 
187 }  // namespace cast
188 }  // namespace openscreen
189