1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "cast/sender/channel/cast_auth_util.h"

#include <memory>
#include <string>

#include "cast/common/certificate/cast_cert_validator.h"
#include "cast/common/certificate/cast_crl.h"
#include "cast/common/certificate/proto/test_suite.pb.h"
#include "cast/common/certificate/testing/test_helpers.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "gtest/gtest.h"
#include "platform/api/time.h"
#include "platform/test/paths.h"
#include "testing/util/read_file.h"
#include "util/crypto/pem_helpers.h"
#include "util/osp_logging.h"
20 
21 namespace openscreen {
22 namespace cast {
23 
24 // TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to
25 // openscreen::cast
26 using DeviceCertTestSuite = ::cast::certificate::DeviceCertTestSuite;
27 using VerificationResult = ::cast::certificate::VerificationResult;
28 using DeviceCertTest = ::cast::certificate::DeviceCertTest;
29 
30 namespace {
31 
32 using ::cast::channel::AuthResponse;
33 
ConvertTimeSeconds(const DateTime & time,uint64_t * seconds)34 bool ConvertTimeSeconds(const DateTime& time, uint64_t* seconds) {
35   static constexpr uint64_t kDaysPerYear = 365;
36   static constexpr uint64_t kHoursPerDay = 24;
37   static constexpr uint64_t kMinutesPerHour = 60;
38   static constexpr uint64_t kSecondsPerMinute = 60;
39 
40   static constexpr uint64_t kSecondsPerDay =
41       kSecondsPerMinute * kMinutesPerHour * kHoursPerDay;
42   static constexpr uint64_t kDaysPerQuadYear = 4 * kDaysPerYear + 1;
43   static constexpr uint64_t kDaysPerCentury =
44       kDaysPerQuadYear * 24 + kDaysPerYear * 4;
45   static constexpr uint64_t kDaysPerQuadCentury = 4 * kDaysPerCentury + 1;
46 
47   static constexpr uint64_t kDaysPerMonth[] = {
48       31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31,
49   };
50 
51   bool is_leap_year =
52       (time.year % 4 == 0 && (time.year % 100 != 0 || time.year % 400 == 0));
53   if (time.year < 1970 || time.month < 1 || time.day < 1 ||
54       time.day > (kDaysPerMonth[time.month - 1] + is_leap_year) ||
55       time.month > 12 || time.hour > 23 || time.minute > 59 ||
56       time.second > 60) {
57     return false;
58   }
59   uint64_t result = 0;
60   uint64_t year = time.year - 1970;
61   uint64_t first_two_years = year >= 2;
62   result += first_two_years * 2 * kDaysPerYear * kSecondsPerDay;
63   year -= first_two_years * 2;
64 
65   if (first_two_years) {
66     uint64_t twenty_eight_years = year >= 28;
67     result += twenty_eight_years * 7 * kDaysPerQuadYear * kSecondsPerDay;
68     year -= twenty_eight_years * 28;
69 
70     if (twenty_eight_years) {
71       uint64_t quad_centuries = year / 400;
72       result += quad_centuries * kDaysPerQuadCentury * kSecondsPerDay;
73       year -= quad_centuries * 400;
74 
75       uint64_t first_century = year >= 100;
76       result += first_century * (kDaysPerCentury + 1) * kSecondsPerDay;
77       year -= first_century * 100;
78 
79       uint64_t centuries = year / 100;
80       result += centuries * kDaysPerCentury * kSecondsPerDay;
81       year -= centuries * 100;
82     }
83 
84     uint64_t quad_years = year / 4;
85     result += quad_years * kDaysPerQuadYear * kSecondsPerDay;
86     year -= quad_years * 4;
87 
88     uint64_t first_year = year >= 1;
89     result += first_year * (kDaysPerYear + 1) * kSecondsPerDay;
90     year -= first_year;
91 
92     result += year * kDaysPerYear * kSecondsPerDay;
93     OSP_DCHECK_LE(year, 2);
94   }
95 
96   for (int i = 0; i < time.month - 1; ++i) {
97     uint64_t days = kDaysPerMonth[i];
98     result += days * kSecondsPerDay;
99   }
100   if (time.month >= 3 && is_leap_year) {
101     result += kSecondsPerDay;
102   }
103   result += (time.day - 1) * kSecondsPerDay;
104   result += time.hour * kMinutesPerHour * kSecondsPerMinute;
105   result += time.minute * kSecondsPerMinute;
106   result += time.second;
107 
108   *seconds = result;
109   return true;
110 }
111 
GetSpecificTestDataPath()112 const std::string& GetSpecificTestDataPath() {
113   static std::string data_path = GetTestDataPath() + "cast/common/certificate/";
114   return data_path;
115 }
116 
117 class CastAuthUtilTest : public ::testing::Test {
118  public:
CastAuthUtilTest()119   CastAuthUtilTest() {}
~CastAuthUtilTest()120   ~CastAuthUtilTest() override {}
121 
SetUp()122   void SetUp() override {}
123 
124  protected:
CreateAuthResponse(std::vector<uint8_t> * signed_data,::cast::channel::HashAlgorithm digest_algorithm)125   static AuthResponse CreateAuthResponse(
126       std::vector<uint8_t>* signed_data,
127       ::cast::channel::HashAlgorithm digest_algorithm) {
128     std::vector<std::string> chain = ReadCertificatesFromPemFile(
129         GetSpecificTestDataPath() + "certificates/chromecast_gen1.pem");
130     OSP_CHECK(!chain.empty());
131 
132     testing::SignatureTestData signatures = testing::ReadSignatureTestData(
133         GetSpecificTestDataPath() + "signeddata/2ZZBG9_FA8FCA3EF91A.pem");
134 
135     AuthResponse response;
136 
137     response.set_client_auth_certificate(chain[0]);
138     for (size_t i = 1; i < chain.size(); ++i) {
139       response.add_intermediate_certificate(chain[i]);
140     }
141 
142     response.set_hash_algorithm(digest_algorithm);
143     switch (digest_algorithm) {
144       case ::cast::channel::SHA1:
145         response.set_signature(
146             std::string(reinterpret_cast<const char*>(signatures.sha1.data),
147                         signatures.sha1.length));
148         break;
149       case ::cast::channel::SHA256:
150         response.set_signature(
151             std::string(reinterpret_cast<const char*>(signatures.sha256.data),
152                         signatures.sha256.length));
153         break;
154     }
155     *signed_data = std::vector<uint8_t>(
156         signatures.message.data,
157         signatures.message.data + signatures.message.length);
158 
159     return response;
160   }
161 
162   // Mangles a string by inverting the first byte.
MangleString(std::string * str)163   static void MangleString(std::string* str) { (*str)[0] = ~(*str)[0]; }
164 
165   // Mangles a vector by inverting the first byte.
MangleData(std::vector<uint8_t> * data)166   static void MangleData(std::vector<uint8_t>* data) {
167     (*data)[0] = ~(*data)[0];
168   }
169 
170   const std::string& data_path_{GetSpecificTestDataPath()};
171 };
172 
173 // Note on expiration: VerifyCredentials() depends on the system clock. In
174 // practice this shouldn't be a problem though since the certificate chain
175 // being verified doesn't expire until 2032.
TEST_F(CastAuthUtilTest,VerifySuccess)176 TEST_F(CastAuthUtilTest, VerifySuccess) {
177   std::vector<uint8_t> signed_data;
178   AuthResponse auth_response =
179       CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
180   DateTime now = {};
181   ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
182   ErrorOr<CastDeviceCertPolicy> result =
183       VerifyCredentialsForTest(auth_response, signed_data,
184                                CRLPolicy::kCrlOptional, nullptr, nullptr, now);
185   EXPECT_TRUE(result);
186   EXPECT_EQ(CastDeviceCertPolicy::kUnrestricted, result.value());
187 }
188 
// Verifies that corrupting the first intermediate certificate makes
// verification fail with a certificate parse error.
TEST_F(CastAuthUtilTest, VerifyBadCA) {
  std::vector<uint8_t> signed_data;
  AuthResponse auth_response =
      CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
  MangleString(auth_response.mutable_intermediate_certificate(0));
  ErrorOr<CastDeviceCertPolicy> result =
      VerifyCredentials(auth_response, signed_data);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kErrCertsParse, result.error().code());
}
199 
// Verifies that corrupting the client auth certificate makes verification fail
// with a certificate parse error.
TEST_F(CastAuthUtilTest, VerifyBadClientAuthCert) {
  std::vector<uint8_t> signed_data;
  AuthResponse auth_response =
      CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
  MangleString(auth_response.mutable_client_auth_certificate());
  ErrorOr<CastDeviceCertPolicy> result =
      VerifyCredentials(auth_response, signed_data);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kErrCertsParse, result.error().code());
}
210 
// Verifies that corrupting the signature makes verification fail with a
// signed-blobs mismatch error.
TEST_F(CastAuthUtilTest, VerifyBadSignature) {
  std::vector<uint8_t> signed_data;
  AuthResponse auth_response =
      CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
  MangleString(auth_response.mutable_signature());
  ErrorOr<CastDeviceCertPolicy> result =
      VerifyCredentials(auth_response, signed_data);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2SignedBlobsMismatch, result.error().code());
}
221 
// Verifies that an auth response whose signature has been cleared is rejected
// with the empty-signature error.
TEST_F(CastAuthUtilTest, VerifyEmptySignature) {
  std::vector<uint8_t> signed_data;
  AuthResponse auth_response =
      CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
  auth_response.mutable_signature()->clear();
  ErrorOr<CastDeviceCertPolicy> result =
      VerifyCredentials(auth_response, signed_data);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2SignatureEmpty, result.error().code());
}
232 
// Verifies that a SHA-1 auth response is rejected with the unsupported-digest
// error when the trailing flag of VerifyCredentialsForTest() is set.
// NOTE(review): that flag presumably enables strict SHA-256 enforcement;
// confirm against the declaration of VerifyCredentialsForTest().
TEST_F(CastAuthUtilTest, VerifyUnsupportedDigest) {
  std::vector<uint8_t> signed_data;
  AuthResponse auth_response =
      CreateAuthResponse(&signed_data, ::cast::channel::SHA1);
  DateTime now = {};
  ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
  ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest(
      auth_response, signed_data, CRLPolicy::kCrlOptional, nullptr, nullptr,
      now, true);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2DigestUnsupported, result.error().code());
}
245 
// Verifies that a SHA-1 auth response is still accepted when
// VerifyCredentialsForTest() is called without its trailing flag.
TEST_F(CastAuthUtilTest, VerifyBackwardsCompatibleDigest) {
  std::vector<uint8_t> payload;
  AuthResponse sha1_response =
      CreateAuthResponse(&payload, ::cast::channel::SHA1);

  // Use the current wall-clock time as the verification time.
  const auto seconds_since_epoch = GetWallTimeSinceUnixEpoch().count();
  DateTime verification_time = {};
  ASSERT_TRUE(DateTimeFromSeconds(seconds_since_epoch, &verification_time));

  ErrorOr<CastDeviceCertPolicy> verification = VerifyCredentialsForTest(
      sha1_response, payload, CRLPolicy::kCrlOptional, nullptr, nullptr,
      verification_time);
  EXPECT_TRUE(verification);
}
257 
// Verifies that corrupting the signed payload, while leaving the response
// itself intact, makes verification fail with a signed-blobs mismatch error.
TEST_F(CastAuthUtilTest, VerifyBadPeerCert) {
  std::vector<uint8_t> signed_data;
  AuthResponse auth_response =
      CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
  MangleData(&signed_data);
  ErrorOr<CastDeviceCertPolicy> result =
      VerifyCredentials(auth_response, signed_data);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2SignedBlobsMismatch, result.error().code());
}
268 
// Verifies that a context accepts its own nonce when it is echoed back.
// NOTE(review): the `true` argument presumably turns on strict nonce
// checking; confirm against the declaration of VerifySenderNonce().
TEST_F(CastAuthUtilTest, VerifySenderNonceMatch) {
  AuthContext auth_context = AuthContext::Create();
  const Error nonce_check =
      auth_context.VerifySenderNonce(auth_context.nonce(), true);
  EXPECT_TRUE(nonce_check.ok());
}
274 
// Verifies that a nonce different from the one held by the context is
// rejected with the nonce-mismatch error.
TEST_F(CastAuthUtilTest, VerifySenderNonceMismatch) {
  AuthContext context = AuthContext::Create();
  std::string received_nonce = "test2";
  EXPECT_NE(received_nonce, context.nonce());
  ErrorOr<CastDeviceCertPolicy> result =
      context.VerifySenderNonce(received_nonce, true);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2SenderNonceMismatch, result.error().code());
}
284 
// Verifies that an empty received nonce is rejected with the nonce-mismatch
// error when the context itself holds a non-empty nonce.
TEST_F(CastAuthUtilTest, VerifySenderNonceMissing) {
  AuthContext context = AuthContext::Create();
  std::string received_nonce;
  EXPECT_FALSE(context.nonce().empty());
  ErrorOr<CastDeviceCertPolicy> result =
      context.VerifySenderNonce(received_nonce, true);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2SenderNonceMismatch, result.error().code());
}
294 
// Verifies that the test TLS certificate is considered valid at exactly its
// notBefore instant.
TEST_F(CastAuthUtilTest, VerifyTLSCertificateSuccess) {
  std::vector<std::string> tls_cert_der = ReadCertificatesFromPemFile(
      data_path_ + "certificates/test_tls_cert.pem");
  ASSERT_FALSE(tls_cert_der.empty());
  const std::string& der_cert = tls_cert_der[0];
  const uint8_t* data = reinterpret_cast<const uint8_t*>(der_cert.data());
  // Owned through unique_ptr so the certificate is released even when one of
  // the ASSERT_* checks below returns early.
  std::unique_ptr<X509, decltype(&X509_free)> tls_cert(
      d2i_X509(nullptr, &data, der_cert.size()), &X509_free);
  ASSERT_TRUE(tls_cert != nullptr);
  DateTime not_before = {};
  DateTime not_after = {};
  ASSERT_TRUE(
      GetCertValidTimeRange(tls_cert.get(), &not_before, &not_after));
  uint64_t not_before_seconds = 0;
  ASSERT_TRUE(ConvertTimeSeconds(not_before, &not_before_seconds));
  std::chrono::seconds verification_time(not_before_seconds);

  const Error result =
      VerifyTLSCertificateValidity(tls_cert.get(), verification_time);
  EXPECT_TRUE(result.ok());
}
312 
// Verifies that the test TLS certificate is rejected one second before its
// notBefore instant.
TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) {
  std::vector<std::string> tls_cert_der = ReadCertificatesFromPemFile(
      data_path_ + "certificates/test_tls_cert.pem");
  ASSERT_FALSE(tls_cert_der.empty());
  const std::string& der_cert = tls_cert_der[0];
  const uint8_t* data = reinterpret_cast<const uint8_t*>(der_cert.data());
  // Owned through unique_ptr so the certificate is released even when one of
  // the ASSERT_* checks below returns early.
  std::unique_ptr<X509, decltype(&X509_free)> tls_cert(
      d2i_X509(nullptr, &data, der_cert.size()), &X509_free);
  ASSERT_TRUE(tls_cert != nullptr);
  DateTime not_before = {};
  DateTime not_after = {};
  ASSERT_TRUE(
      GetCertValidTimeRange(tls_cert.get(), &not_before, &not_after));
  uint64_t not_before_seconds = 0;
  ASSERT_TRUE(ConvertTimeSeconds(not_before, &not_before_seconds));
  std::chrono::seconds verification_time(not_before_seconds - 1);

  ErrorOr<CastDeviceCertPolicy> result =
      VerifyTLSCertificateValidity(tls_cert.get(), verification_time);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2TlsCertValidStartDateInFuture,
            result.error().code());
}
333 
// Verifies that the test TLS certificate is rejected as expired two seconds
// after its notAfter instant.
TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooLate) {
  std::vector<std::string> tls_cert_der = ReadCertificatesFromPemFile(
      data_path_ + "certificates/test_tls_cert.pem");
  ASSERT_FALSE(tls_cert_der.empty());
  const std::string& der_cert = tls_cert_der[0];
  const uint8_t* data = reinterpret_cast<const uint8_t*>(der_cert.data());
  // Owned through unique_ptr so the certificate is released even when one of
  // the ASSERT_* checks below returns early.
  std::unique_ptr<X509, decltype(&X509_free)> tls_cert(
      d2i_X509(nullptr, &data, der_cert.size()), &X509_free);
  ASSERT_TRUE(tls_cert != nullptr);
  DateTime not_before = {};
  DateTime not_after = {};
  ASSERT_TRUE(
      GetCertValidTimeRange(tls_cert.get(), &not_before, &not_after));
  uint64_t not_after_seconds = 0;
  ASSERT_TRUE(ConvertTimeSeconds(not_after, &not_after_seconds));
  std::chrono::seconds verification_time(not_after_seconds + 2);

  ErrorOr<CastDeviceCertPolicy> result =
      VerifyTLSCertificateValidity(tls_cert.get(), verification_time);
  // Fatal on failure: error() must not be read from a result holding a value.
  ASSERT_FALSE(result);
  EXPECT_EQ(Error::Code::kCastV2TlsCertExpired, result.error().code());
}
353 
// Expected outcome of a single test step's verification.
enum TestStepResult {
  RESULT_SUCCESS = 0,  // The step is expected to verify successfully.
  RESULT_FAIL = 1,     // The step is expected to fail verification.
};
359 
360 // Verifies that the certificate chain provided is not revoked according to
361 // the provided Cast CRL at |verification_time|.
362 // The provided CRL is verified at |verification_time|.
363 // If |crl_required| is set, then a valid Cast CRL must be provided.
364 // Otherwise, a missing CRL is be ignored.
TestVerifyRevocation(const std::vector<std::string> & certificate_chain,const std::string & crl_bundle,const DateTime & verification_time,bool crl_required,TrustStore * cast_trust_store,TrustStore * crl_trust_store)365 ErrorOr<CastDeviceCertPolicy> TestVerifyRevocation(
366     const std::vector<std::string>& certificate_chain,
367     const std::string& crl_bundle,
368     const DateTime& verification_time,
369     bool crl_required,
370     TrustStore* cast_trust_store,
371     TrustStore* crl_trust_store) {
372   AuthResponse response;
373 
374   if (certificate_chain.size() > 0) {
375     response.set_client_auth_certificate(certificate_chain[0]);
376     for (size_t i = 1; i < certificate_chain.size(); ++i) {
377       response.add_intermediate_certificate(certificate_chain[i]);
378     }
379   }
380 
381   response.set_crl(crl_bundle);
382 
383   CRLPolicy crl_policy = CRLPolicy::kCrlRequired;
384   if (!crl_required && crl_bundle.empty())
385     crl_policy = CRLPolicy::kCrlOptional;
386   ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest(
387       response, std::vector<uint8_t>(), crl_policy, cast_trust_store,
388       crl_trust_store, verification_time);
389   // This test doesn't set the signature so it will just fail there.
390   EXPECT_FALSE(result);
391   return result;
392 }
393 
394 // Runs a single test case.
RunTest(const DeviceCertTest & test_case)395 bool RunTest(const DeviceCertTest& test_case) {
396   TrustStore crl_trust_store;
397   TrustStore cast_trust_store;
398   if (test_case.use_test_trust_anchors()) {
399     crl_trust_store = TrustStore::CreateInstanceFromPemFile(
400         GetSpecificTestDataPath() + "certificates/cast_crl_test_root_ca.pem");
401     cast_trust_store = TrustStore::CreateInstanceFromPemFile(
402         GetSpecificTestDataPath() + "certificates/cast_test_root_ca.pem");
403 
404     EXPECT_FALSE(crl_trust_store.certs.empty());
405     EXPECT_FALSE(cast_trust_store.certs.empty());
406   }
407 
408   std::vector<std::string> certificate_chain;
409   for (auto const& cert : test_case.der_cert_path()) {
410     certificate_chain.push_back(cert);
411   }
412 
413   // CastAuthUtil verifies the CRL at the same time as the certificate.
414   DateTime verification_time;
415   uint64_t cert_verify_time = test_case.cert_verification_time_seconds();
416   if (!cert_verify_time) {
417     cert_verify_time = test_case.crl_verification_time_seconds();
418   }
419   OSP_DCHECK(DateTimeFromSeconds(cert_verify_time, &verification_time));
420 
421   std::string crl_bundle = test_case.crl_bundle();
422   ErrorOr<CastDeviceCertPolicy> result(CastDeviceCertPolicy::kUnrestricted);
423   switch (test_case.expected_result()) {
424     case ::cast::certificate::PATH_VERIFICATION_FAILED:
425       result =
426           TestVerifyRevocation(certificate_chain, crl_bundle, verification_time,
427                                false, &cast_trust_store, &cast_trust_store);
428       EXPECT_EQ(result.error().code(),
429                 Error::Code::kCastV2CertNotSignedByTrustedCa);
430       return result.error().code() ==
431              Error::Code::kCastV2CertNotSignedByTrustedCa;
432     case ::cast::certificate::CRL_VERIFICATION_FAILED:
433     // Fall-through intended.
434     case ::cast::certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL:
435       result =
436           TestVerifyRevocation(certificate_chain, crl_bundle, verification_time,
437                                true, &cast_trust_store, &cast_trust_store);
438       EXPECT_EQ(result.error().code(), Error::Code::kErrCrlInvalid);
439       return result.error().code() == Error::Code::kErrCrlInvalid;
440     case ::cast::certificate::CRL_EXPIRED_AFTER_INITIAL_VERIFICATION:
441       // By-pass this test because CRL is always verified at the time the
442       // certificate is verified.
443       return true;
444     case ::cast::certificate::REVOCATION_CHECK_FAILED:
445       result =
446           TestVerifyRevocation(certificate_chain, crl_bundle, verification_time,
447                                true, &cast_trust_store, &cast_trust_store);
448       EXPECT_EQ(result.error().code(), Error::Code::kErrCertsRevoked);
449       return result.error().code() == Error::Code::kErrCertsRevoked;
450     case ::cast::certificate::SUCCESS:
451       result =
452           TestVerifyRevocation(certificate_chain, crl_bundle, verification_time,
453                                false, &cast_trust_store, &cast_trust_store);
454       EXPECT_EQ(result.error().code(), Error::Code::kCastV2SignedBlobsMismatch);
455       return result.error().code() == Error::Code::kCastV2SignedBlobsMismatch;
456     case ::cast::certificate::UNSPECIFIED:
457       return false;
458   }
459   return false;
460 }
461 
462 // Parses the provided test suite provided in wire-format proto.
463 // Each test contains the inputs and the expected output.
464 // To see the description of the test, execute the test.
465 // These tests are generated by a test generator in google3.
RunTestSuite(const std::string & test_suite_file_name)466 void RunTestSuite(const std::string& test_suite_file_name) {
467   std::string testsuite_raw = ReadEntireFileToString(test_suite_file_name);
468   DeviceCertTestSuite test_suite;
469   EXPECT_TRUE(test_suite.ParseFromString(testsuite_raw));
470   uint16_t successes = 0;
471 
472   for (auto const& test_case : test_suite.tests()) {
473     bool result = RunTest(test_case);
474     EXPECT_TRUE(result) << test_case.description();
475     successes += result;
476   }
477   OSP_LOG_IF(ERROR, successes != test_suite.tests().size())
478       << "successes: " << successes
479       << ", failures: " << (test_suite.tests().size() - successes);
480 }
481 
// Runs the generated CRL test suite. The suite is located under the
// certificate test data directory, like every other fixture in this file,
// rather than relative to the current working directory.
TEST_F(CastAuthUtilTest, CRLTestSuite) {
  RunTestSuite(data_path_ + "testsuite/testsuite1.pb");
}
485 
486 }  // namespace
487 }  // namespace cast
488 }  // namespace openscreen
489