1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "nss.h"
8 #include "secerr.h"
9 #include "ssl.h"
10 #include "ssl3prot.h"
11 #include "sslerr.h"
12 #include "sslproto.h"
13 
14 #include "gtest_utils.h"
15 #include "nss_scoped_ptrs.h"
16 #include "tls_connect.h"
17 #include "tls_filter.h"
18 #include "tls_parser.h"
19 
20 #include <iostream>
21 
22 namespace nss_test {
23 
GetSSLVersionString(uint16_t v)24 std::string GetSSLVersionString(uint16_t v) {
25   switch (v) {
26     case SSL_LIBRARY_VERSION_3_0:
27       return "ssl3";
28     case SSL_LIBRARY_VERSION_TLS_1_0:
29       return "tls1.0";
30     case SSL_LIBRARY_VERSION_TLS_1_1:
31       return "tls1.1";
32     case SSL_LIBRARY_VERSION_TLS_1_2:
33       return "tls1.2";
34     case SSL_LIBRARY_VERSION_TLS_1_3:
35       return "tls1.3";
36     case SSL_LIBRARY_VERSION_NONE:
37       return "NONE";
38   }
39   if (v < SSL_LIBRARY_VERSION_3_0) {
40     return "undefined-too-low";
41   }
42   return "undefined-too-high";
43 }
44 
operator <<(std::ostream & stream,const SSLVersionRange & vr)45 inline std::ostream& operator<<(std::ostream& stream,
46                                 const SSLVersionRange& vr) {
47   return stream << GetSSLVersionString(vr.min) << ","
48                 << GetSSLVersionString(vr.max);
49 }
50 
51 class VersionRangeWithLabel {
52  public:
VersionRangeWithLabel(const std::string & txt,const SSLVersionRange & vr)53   VersionRangeWithLabel(const std::string& txt, const SSLVersionRange& vr)
54       : label_(txt), vr_(vr) {}
VersionRangeWithLabel(const std::string & txt,uint16_t start,uint16_t end)55   VersionRangeWithLabel(const std::string& txt, uint16_t start, uint16_t end)
56       : label_(txt) {
57     vr_.min = start;
58     vr_.max = end;
59   }
VersionRangeWithLabel(const std::string & label)60   VersionRangeWithLabel(const std::string& label) : label_(label) {
61     vr_.min = vr_.max = SSL_LIBRARY_VERSION_NONE;
62   }
63 
WriteStream(std::ostream & stream) const64   void WriteStream(std::ostream& stream) const {
65     stream << " " << label_ << ": " << vr_;
66   }
67 
min() const68   uint16_t min() const { return vr_.min; }
max() const69   uint16_t max() const { return vr_.max; }
range() const70   SSLVersionRange range() const { return vr_; }
71 
72  private:
73   std::string label_;
74   SSLVersionRange vr_;
75 };
76 
operator <<(std::ostream & stream,const VersionRangeWithLabel & vrwl)77 inline std::ostream& operator<<(std::ostream& stream,
78                                 const VersionRangeWithLabel& vrwl) {
79   vrwl.WriteStream(stream);
80   return stream;
81 }
82 
83 typedef std::tuple<SSLProtocolVariant,  // variant
84                    uint16_t,            // policy min
85                    uint16_t,            // policy max
86                    uint16_t,            // input min
87                    uint16_t>            // input max
88     PolicyVersionRangeInput;
89 
90 class TestPolicyVersionRange
91     : public TlsConnectTestBase,
92       public ::testing::WithParamInterface<PolicyVersionRangeInput> {
93  public:
TestPolicyVersionRange()94   TestPolicyVersionRange()
95       : TlsConnectTestBase(std::get<0>(GetParam()), 0),
96         variant_(std::get<0>(GetParam())),
97         policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())),
98         input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())),
99         library_("supported-by-library",
100                  ((variant_ == ssl_variant_stream)
101                       ? SSL_LIBRARY_VERSION_MIN_SUPPORTED_STREAM
102                       : SSL_LIBRARY_VERSION_MIN_SUPPORTED_DATAGRAM),
103                  SSL_LIBRARY_VERSION_MAX_SUPPORTED) {
104     TlsConnectTestBase::SkipVersionChecks();
105   }
106 
SetPolicy(const SSLVersionRange & policy)107   void SetPolicy(const SSLVersionRange& policy) {
108     NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, 0);
109 
110     SECStatus rv;
111     rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, policy.min);
112     ASSERT_EQ(SECSuccess, rv);
113     rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, policy.max);
114     ASSERT_EQ(SECSuccess, rv);
115     rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, policy.min);
116     ASSERT_EQ(SECSuccess, rv);
117     rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, policy.max);
118     ASSERT_EQ(SECSuccess, rv);
119   }
120 
CreateDummySocket(std::shared_ptr<DummyPrSocket> * dummy_socket,ScopedPRFileDesc * ssl_fd)121   void CreateDummySocket(std::shared_ptr<DummyPrSocket>* dummy_socket,
122                          ScopedPRFileDesc* ssl_fd) {
123     (*dummy_socket).reset(new DummyPrSocket("dummy", variant_));
124     *ssl_fd = (*dummy_socket)->CreateFD();
125     if (variant_ == ssl_variant_stream) {
126       SSL_ImportFD(nullptr, ssl_fd->get());
127     } else {
128       DTLS_ImportFD(nullptr, ssl_fd->get());
129     }
130   }
131 
GetOverlap(const SSLVersionRange & r1,const SSLVersionRange & r2,SSLVersionRange * overlap)132   bool GetOverlap(const SSLVersionRange& r1, const SSLVersionRange& r2,
133                   SSLVersionRange* overlap) {
134     if (r1.min == SSL_LIBRARY_VERSION_NONE ||
135         r1.max == SSL_LIBRARY_VERSION_NONE ||
136         r2.min == SSL_LIBRARY_VERSION_NONE ||
137         r2.max == SSL_LIBRARY_VERSION_NONE) {
138       return false;
139     }
140 
141     SSLVersionRange temp;
142     temp.min = PR_MAX(r1.min, r2.min);
143     temp.max = PR_MIN(r1.max, r2.max);
144 
145     if (temp.min > temp.max) {
146       return false;
147     }
148 
149     *overlap = temp;
150     return true;
151   }
152 
IsValidInputForVersionRangeSet(SSLVersionRange * expectedEffectiveRange)153   bool IsValidInputForVersionRangeSet(SSLVersionRange* expectedEffectiveRange) {
154     if (input_.min() <= SSL_LIBRARY_VERSION_3_0 &&
155         input_.max() >= SSL_LIBRARY_VERSION_TLS_1_3) {
156       // This is always invalid input, independent of policy
157       return false;
158     }
159 
160     if (input_.min() < library_.min() || input_.max() > library_.max() ||
161         input_.min() > input_.max()) {
162       // Asking for unsupported ranges is invalid input for VersionRangeSet
163       // APIs, regardless of overlap.
164       return false;
165     }
166 
167     SSLVersionRange overlap_with_library;
168     if (!GetOverlap(input_.range(), library_.range(), &overlap_with_library)) {
169       return false;
170     }
171 
172     SSLVersionRange overlap_with_library_and_policy;
173     if (!GetOverlap(overlap_with_library, policy_.range(),
174                     &overlap_with_library_and_policy)) {
175       return false;
176     }
177 
178     RemoveConflictingVersions(variant_, &overlap_with_library_and_policy);
179     *expectedEffectiveRange = overlap_with_library_and_policy;
180     return true;
181   }
182 
RemoveConflictingVersions(SSLProtocolVariant variant,SSLVersionRange * r)183   void RemoveConflictingVersions(SSLProtocolVariant variant,
184                                  SSLVersionRange* r) {
185     ASSERT_TRUE(r != nullptr);
186     if (r->max >= SSL_LIBRARY_VERSION_TLS_1_3 &&
187         r->min < SSL_LIBRARY_VERSION_TLS_1_0) {
188       r->min = SSL_LIBRARY_VERSION_TLS_1_0;
189     }
190   }
191 
SetUp()192   void SetUp() override {
193     TlsConnectTestBase::SetUp();
194     SetPolicy(policy_.range());
195   }
196 
TearDown()197   void TearDown() override {
198     TlsConnectTestBase::TearDown();
199     saved_version_policy_.RestoreOriginalPolicy();
200   }
201 
202  protected:
203   class VersionPolicy {
204    public:
VersionPolicy()205     VersionPolicy() { SaveOriginalPolicy(); }
206 
RestoreOriginalPolicy()207     void RestoreOriginalPolicy() {
208       SECStatus rv;
209       rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, saved_min_tls_);
210       ASSERT_EQ(SECSuccess, rv);
211       rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, saved_max_tls_);
212       ASSERT_EQ(SECSuccess, rv);
213       rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, saved_min_dtls_);
214       ASSERT_EQ(SECSuccess, rv);
215       rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_);
216       ASSERT_EQ(SECSuccess, rv);
217     }
218 
219    private:
SaveOriginalPolicy()220     void SaveOriginalPolicy() {
221       SECStatus rv;
222       rv = NSS_OptionGet(NSS_TLS_VERSION_MIN_POLICY, &saved_min_tls_);
223       ASSERT_EQ(SECSuccess, rv);
224       rv = NSS_OptionGet(NSS_TLS_VERSION_MAX_POLICY, &saved_max_tls_);
225       ASSERT_EQ(SECSuccess, rv);
226       rv = NSS_OptionGet(NSS_DTLS_VERSION_MIN_POLICY, &saved_min_dtls_);
227       ASSERT_EQ(SECSuccess, rv);
228       rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_);
229       ASSERT_EQ(SECSuccess, rv);
230     }
231 
232     int32_t saved_min_tls_;
233     int32_t saved_max_tls_;
234     int32_t saved_min_dtls_;
235     int32_t saved_max_dtls_;
236   };
237 
238   VersionPolicy saved_version_policy_;
239 
240   SSLProtocolVariant variant_;
241   const VersionRangeWithLabel policy_;
242   const VersionRangeWithLabel input_;
243   const VersionRangeWithLabel library_;
244 };
245 
246 static const uint16_t kExpandedVersionsArr[] = {
247     /* clang-format off */
248     SSL_LIBRARY_VERSION_3_0 - 1,
249     SSL_LIBRARY_VERSION_3_0,
250     SSL_LIBRARY_VERSION_TLS_1_0,
251     SSL_LIBRARY_VERSION_TLS_1_1,
252     SSL_LIBRARY_VERSION_TLS_1_2,
253 #ifndef NSS_DISABLE_TLS_1_3
254     SSL_LIBRARY_VERSION_TLS_1_3,
255 #endif
256     SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1
257     /* clang-format on */
258 };
259 static ::testing::internal::ParamGenerator<uint16_t> kExpandedVersions =
260     ::testing::ValuesIn(kExpandedVersionsArr);
261 
TEST_P(TestPolicyVersionRange,TestAllTLSVersionsAndPolicyCombinations)262 TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) {
263   ASSERT_TRUE(variant_ == ssl_variant_stream ||
264               variant_ == ssl_variant_datagram)
265       << "testing unsupported ssl variant";
266 
267   std::cerr << "testing: " << variant_ << policy_ << input_ << library_
268             << std::endl;
269 
270   SSLVersionRange supported_range;
271   SECStatus rv = SSL_VersionRangeGetSupported(variant_, &supported_range);
272   VersionRangeWithLabel supported("SSL_VersionRangeGetSupported",
273                                   supported_range);
274 
275   std::cerr << supported << std::endl;
276 
277   std::shared_ptr<DummyPrSocket> dummy_socket;
278   ScopedPRFileDesc ssl_fd;
279   CreateDummySocket(&dummy_socket, &ssl_fd);
280 
281   SECStatus rv_socket;
282   SSLVersionRange overlap_policy_and_lib;
283   if (!GetOverlap(policy_.range(), library_.range(), &overlap_policy_and_lib)) {
284     EXPECT_EQ(SECFailure, rv)
285         << "expected SSL_VersionRangeGetSupported to fail with invalid policy";
286 
287     SSLVersionRange enabled_range;
288     rv = SSL_VersionRangeGetDefault(variant_, &enabled_range);
289     EXPECT_EQ(SECFailure, rv)
290         << "expected SSL_VersionRangeGetDefault to fail with invalid policy";
291 
292     SSLVersionRange enabled_range_on_socket;
293     rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &enabled_range_on_socket);
294     EXPECT_EQ(SECFailure, rv_socket)
295         << "expected SSL_VersionRangeGet to fail with invalid policy";
296 
297     ConnectExpectFail();
298     return;
299   }
300 
301   EXPECT_EQ(SECSuccess, rv)
302       << "expected SSL_VersionRangeGetSupported to succeed with valid policy";
303 
304   EXPECT_TRUE(supported_range.min != SSL_LIBRARY_VERSION_NONE &&
305               supported_range.max != SSL_LIBRARY_VERSION_NONE)
306       << "expected SSL_VersionRangeGetSupported to return real values with "
307          "valid policy";
308 
309   RemoveConflictingVersions(variant_, &overlap_policy_and_lib);
310   VersionRangeWithLabel overlap_info("overlap", overlap_policy_and_lib);
311 
312   EXPECT_TRUE(supported_range == overlap_policy_and_lib)
313       << "expected range from GetSupported to be identical with calculated "
314          "overlap "
315       << overlap_info;
316 
317   // We don't know which versions are "enabled by default" by the library,
318   // therefore we don't know if there's overlap between the default
319   // and the policy, and therefore, we don't if TLS connections should
320   // be successful or fail in this combination.
321   // Therefore we don't test if we can connect, without having configured a
322   // version range explicitly.
323 
324   // Now start testing with supplied input.
325 
326   SSLVersionRange expected_effective_range;
327   bool is_valid_input =
328       IsValidInputForVersionRangeSet(&expected_effective_range);
329 
330   SSLVersionRange temp_input = input_.range();
331   rv = SSL_VersionRangeSetDefault(variant_, &temp_input);
332   rv_socket = SSL_VersionRangeSet(ssl_fd.get(), &temp_input);
333 
334   if (!is_valid_input) {
335     EXPECT_EQ(SECFailure, rv)
336         << "expected failure return from SSL_VersionRangeSetDefault";
337 
338     EXPECT_EQ(SECFailure, rv_socket)
339         << "expected failure return from SSL_VersionRangeSet";
340     return;
341   }
342 
343   EXPECT_EQ(SECSuccess, rv)
344       << "expected successful return from SSL_VersionRangeSetDefault";
345 
346   EXPECT_EQ(SECSuccess, rv_socket)
347       << "expected successful return from SSL_VersionRangeSet";
348 
349   SSLVersionRange effective;
350   SSLVersionRange effective_socket;
351 
352   rv = SSL_VersionRangeGetDefault(variant_, &effective);
353   EXPECT_EQ(SECSuccess, rv)
354       << "expected successful return from SSL_VersionRangeGetDefault";
355 
356   rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &effective_socket);
357   EXPECT_EQ(SECSuccess, rv_socket)
358       << "expected successful return from SSL_VersionRangeGet";
359 
360   VersionRangeWithLabel expected_info("expectation", expected_effective_range);
361   VersionRangeWithLabel effective_info("effectively-enabled", effective);
362 
363   EXPECT_TRUE(expected_effective_range == effective)
364       << "range returned by SSL_VersionRangeGetDefault doesn't match "
365          "expectation: "
366       << expected_info << effective_info;
367 
368   EXPECT_TRUE(expected_effective_range == effective_socket)
369       << "range returned by SSL_VersionRangeGet doesn't match "
370          "expectation: "
371       << expected_info << effective_info;
372 
373   // Because we found overlap between policy and supported versions,
374   // and because we have used SetDefault to enable at least one version,
375   // it should be possible to execute an SSL/TLS connection.
376   Connect();
377 }
378 
379 INSTANTIATE_TEST_SUITE_P(TLSVersionRanges, TestPolicyVersionRange,
380                          ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
381                                             kExpandedVersions,
382                                             kExpandedVersions,
383                                             kExpandedVersions,
384                                             kExpandedVersions));
385 }  // namespace nss_test
386