1
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7 // Original author: ekr@rtfm.com
8
#include <algorithm>
#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>
14
15 #ifdef XP_MACOSX
16 // ensure that Apple Security kit enum goes before "sslproto.h"
17 # include <CoreFoundation/CFAvailability.h>
18 # include <Security/CipherSuite.h>
19 #endif
20
21 #include "mozilla/UniquePtr.h"
22
23 #include "sigslot.h"
24
25 #include "logging.h"
26 #include "ssl.h"
27 #include "sslexp.h"
28 #include "sslproto.h"
29
30 #include "nsThreadUtils.h"
31 #include "nsXPCOM.h"
32
33 #include "mediapacket.h"
34 #include "dtlsidentity.h"
35 #include "nricectx.h"
36 #include "nricemediastream.h"
37 #include "transportflow.h"
38 #include "transportlayer.h"
39 #include "transportlayerdtls.h"
40 #include "transportlayerice.h"
41 #include "transportlayerlog.h"
42 #include "transportlayerloopback.h"
43
44 #include "runnable_utils.h"
45
46 #define GTEST_HAS_RTTI 0
47 #include "gtest/gtest.h"
48 #include "gtest_utils.h"
49
50 using namespace mozilla;
51 MOZ_MTLOG_MODULE("mtransport")
52
// TLS/DTLS record content types.
constexpr uint8_t kTlsChangeCipherSpecType = 0x14;
constexpr uint8_t kTlsHandshakeType = 0x16;

// Handshake message types (first byte of a handshake record body).
constexpr uint8_t kTlsHandshakeCertificate = 0x0b;
constexpr uint8_t kTlsHandshakeServerKeyExchange = 0x0c;

// A complete, hand-built 14-byte DTLS record carrying a one-byte
// ChangeCipherSpec message; tests inject it into a live handshake.
constexpr uint8_t kTlsFakeChangeCipherSpec[] = {
    kTlsChangeCipherSpecType,                        // Type
    0xfe, 0xff,                                      // Version (DTLS 1.0)
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10,  // Fictitious sequence #
    0x00, 0x01,                                      // Length
    0x01                                             // Value
};
75
76 // Layer class which can't be initialized.
77 class TransportLayerDummy : public TransportLayer {
78 public:
TransportLayerDummy(bool allow_init,bool * destroyed)79 TransportLayerDummy(bool allow_init, bool* destroyed)
80 : allow_init_(allow_init), destroyed_(destroyed) {
81 *destroyed_ = false;
82 }
83
~TransportLayerDummy()84 virtual ~TransportLayerDummy() { *destroyed_ = true; }
85
InitInternal()86 nsresult InitInternal() override {
87 return allow_init_ ? NS_OK : NS_ERROR_FAILURE;
88 }
89
SendPacket(MediaPacket & packet)90 TransportResult SendPacket(MediaPacket& packet) override {
91 MOZ_CRASH(); // Should never be called.
92 return 0;
93 }
94
95 TRANSPORT_LAYER_ID("lossy")
96
97 private:
98 bool allow_init_;
99 bool* destroyed_;
100 };
101
102 class Inspector {
103 public:
104 virtual ~Inspector() = default;
105
106 virtual void Inspect(TransportLayer* layer, const unsigned char* data,
107 size_t len) = 0;
108 };
109
110 // Class to simulate various kinds of network lossage
111 class TransportLayerLossy : public TransportLayer {
112 public:
TransportLayerLossy()113 TransportLayerLossy() : loss_mask_(0), packet_(0), inspector_(nullptr) {}
114 ~TransportLayerLossy() = default;
115
SendPacket(MediaPacket & packet)116 TransportResult SendPacket(MediaPacket& packet) override {
117 MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "SendPacket(" << packet.len() << ")");
118
119 if (loss_mask_ & (1 << (packet_ % 32))) {
120 MOZ_MTLOG(ML_NOTICE, "Dropping packet");
121 ++packet_;
122 return packet.len();
123 }
124 if (inspector_) {
125 inspector_->Inspect(this, packet.data(), packet.len());
126 }
127
128 ++packet_;
129
130 return downward_->SendPacket(packet);
131 }
132
SetLoss(uint32_t packet)133 void SetLoss(uint32_t packet) { loss_mask_ |= (1 << (packet & 32)); }
134
SetInspector(UniquePtr<Inspector> inspector)135 void SetInspector(UniquePtr<Inspector> inspector) {
136 inspector_ = std::move(inspector);
137 }
138
StateChange(TransportLayer * layer,State state)139 void StateChange(TransportLayer* layer, State state) { TL_SET_STATE(state); }
140
PacketReceived(TransportLayer * layer,MediaPacket & packet)141 void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
142 SignalPacketReceived(this, packet);
143 }
144
145 TRANSPORT_LAYER_ID("lossy")
146
147 protected:
WasInserted()148 void WasInserted() override {
149 downward_->SignalPacketReceived.connect(
150 this, &TransportLayerLossy::PacketReceived);
151 downward_->SignalStateChange.connect(this,
152 &TransportLayerLossy::StateChange);
153
154 TL_SET_STATE(downward_->state());
155 }
156
157 private:
158 uint32_t loss_mask_;
159 uint32_t packet_;
160 UniquePtr<Inspector> inspector_;
161 };
162
// Process DTLS Records
//
// CHECK_LENGTH(expected): when fewer than |expected| bytes remain, records a
// non-fatal gtest failure and makes the enclosing function return false.
// Must be used inside a bool-returning function where remaining() is
// callable. The argument is parenthesized so expressions such as
// "a ? b : c" are compared as a whole.
#define CHECK_LENGTH(expected)      \
  do {                              \
    EXPECT_GE(remaining(), expected); \
    if (remaining() < (expected)) {   \
      return false;                 \
    }                               \
  } while (0)
169
170 class TlsParser {
171 public:
TlsParser(const unsigned char * data,size_t len)172 TlsParser(const unsigned char* data, size_t len) : buffer_(), offset_(0) {
173 buffer_.Copy(data, len);
174 }
175
Read(unsigned char * val)176 bool Read(unsigned char* val) {
177 if (remaining() < 1) {
178 return false;
179 }
180 *val = *ptr();
181 consume(1);
182 return true;
183 }
184
185 // Read an integral type of specified width.
Read(uint32_t * val,size_t len)186 bool Read(uint32_t* val, size_t len) {
187 if (len > sizeof(uint32_t)) return false;
188
189 *val = 0;
190
191 for (size_t i = 0; i < len; ++i) {
192 unsigned char tmp;
193
194 if (!Read(&tmp)) return false;
195
196 (*val) = ((*val) << 8) + tmp;
197 }
198
199 return true;
200 }
201
Read(unsigned char * val,size_t len)202 bool Read(unsigned char* val, size_t len) {
203 if (remaining() < len) {
204 return false;
205 }
206
207 if (val) {
208 memcpy(val, ptr(), len);
209 }
210 consume(len);
211
212 return true;
213 }
214
215 private:
remaining() const216 size_t remaining() const { return buffer_.len() - offset_; }
ptr() const217 const uint8_t* ptr() const { return buffer_.data() + offset_; }
consume(size_t len)218 void consume(size_t len) { offset_ += len; }
219
220 MediaPacket buffer_;
221 size_t offset_;
222 };
223
224 class DtlsRecordParser {
225 public:
DtlsRecordParser(const unsigned char * data,size_t len)226 DtlsRecordParser(const unsigned char* data, size_t len)
227 : buffer_(), offset_(0) {
228 buffer_.Copy(data, len);
229 }
230
NextRecord(uint8_t * ct,UniquePtr<MediaPacket> * buffer)231 bool NextRecord(uint8_t* ct, UniquePtr<MediaPacket>* buffer) {
232 if (!remaining()) return false;
233
234 CHECK_LENGTH(13U);
235 const uint8_t* ctp = reinterpret_cast<const uint8_t*>(ptr());
236 consume(11); // ct + version + length
237
238 const uint16_t* tmp = reinterpret_cast<const uint16_t*>(ptr());
239 size_t length = ntohs(*tmp);
240 consume(2);
241
242 CHECK_LENGTH(length);
243 auto db = MakeUnique<MediaPacket>();
244 db->Copy(ptr(), length);
245 consume(length);
246
247 *ct = *ctp;
248 *buffer = std::move(db);
249
250 return true;
251 }
252
253 private:
remaining() const254 size_t remaining() const { return buffer_.len() - offset_; }
ptr() const255 const uint8_t* ptr() const { return buffer_.data() + offset_; }
consume(size_t len)256 void consume(size_t len) { offset_ += len; }
257
258 MediaPacket buffer_;
259 size_t offset_;
260 };
261
262 // Inspector that parses out DTLS records and passes
263 // them on.
264 class DtlsRecordInspector : public Inspector {
265 public:
Inspect(TransportLayer * layer,const unsigned char * data,size_t len)266 virtual void Inspect(TransportLayer* layer, const unsigned char* data,
267 size_t len) {
268 DtlsRecordParser parser(data, len);
269
270 uint8_t ct;
271 UniquePtr<MediaPacket> buf;
272 while (parser.NextRecord(&ct, &buf)) {
273 OnRecord(layer, ct, buf->data(), buf->len());
274 }
275 }
276
277 virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
278 const unsigned char* record, size_t len) = 0;
279 };
280
281 // Inspector that injects arbitrary packets based on
282 // DTLS records of various types.
283 class DtlsInspectorInjector : public DtlsRecordInspector {
284 public:
DtlsInspectorInjector(uint8_t packet_type,uint8_t handshake_type,const unsigned char * data,size_t len)285 DtlsInspectorInjector(uint8_t packet_type, uint8_t handshake_type,
286 const unsigned char* data, size_t len)
287 : packet_type_(packet_type), handshake_type_(handshake_type) {
288 packet_.Copy(data, len);
289 }
290
OnRecord(TransportLayer * layer,uint8_t content_type,const unsigned char * data,size_t len)291 virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
292 const unsigned char* data, size_t len) {
293 // Only inject once.
294 if (!packet_.data()) {
295 return;
296 }
297
298 // Check that the first byte is as requested.
299 if (content_type != packet_type_) {
300 return;
301 }
302
303 if (handshake_type_ != 0xff) {
304 // Check that the packet is plausibly long enough.
305 if (len < 1) {
306 return;
307 }
308
309 // Check that the handshake type is as requested.
310 if (data[0] != handshake_type_) {
311 return;
312 }
313 }
314
315 layer->SendPacket(packet_);
316 packet_.Reset();
317 }
318
319 private:
320 uint8_t packet_type_;
321 uint8_t handshake_type_;
322 MediaPacket packet_;
323 };
324
325 // Make a copy of the first instance of a message.
326 class DtlsInspectorRecordHandshakeMessage : public DtlsRecordInspector {
327 public:
DtlsInspectorRecordHandshakeMessage(uint8_t handshake_type)328 explicit DtlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
329 : handshake_type_(handshake_type), buffer_() {}
330
OnRecord(TransportLayer * layer,uint8_t content_type,const unsigned char * data,size_t len)331 virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
332 const unsigned char* data, size_t len) {
333 // Only do this once.
334 if (buffer_.len()) {
335 return;
336 }
337
338 // Check that the first byte is as requested.
339 if (content_type != kTlsHandshakeType) {
340 return;
341 }
342
343 TlsParser parser(data, len);
344 unsigned char message_type;
345 // Read the handshake message type.
346 if (!parser.Read(&message_type)) {
347 return;
348 }
349 if (message_type != handshake_type_) {
350 return;
351 }
352
353 uint32_t length;
354 if (!parser.Read(&length, 3)) {
355 return;
356 }
357
358 uint32_t message_seq;
359 if (!parser.Read(&message_seq, 2)) {
360 return;
361 }
362
363 uint32_t fragment_offset;
364 if (!parser.Read(&fragment_offset, 3)) {
365 return;
366 }
367
368 uint32_t fragment_length;
369 if (!parser.Read(&fragment_length, 3)) {
370 return;
371 }
372
373 if ((fragment_offset != 0) || (fragment_length != length)) {
374 // This shouldn't happen because all current tests where we
375 // are using this code don't fragment.
376 return;
377 }
378
379 UniquePtr<uint8_t[]> buffer(new uint8_t[length]);
380 if (!parser.Read(buffer.get(), length)) {
381 return;
382 }
383 buffer_.Take(std::move(buffer), length);
384 }
385
buffer()386 const MediaPacket& buffer() { return buffer_; }
387
388 private:
389 uint8_t handshake_type_;
390 MediaPacket buffer_;
391 };
392
393 class TlsServerKeyExchangeECDHE {
394 public:
Parse(const unsigned char * data,size_t len)395 bool Parse(const unsigned char* data, size_t len) {
396 TlsParser parser(data, len);
397
398 uint8_t curve_type;
399 if (!parser.Read(&curve_type)) {
400 return false;
401 }
402
403 if (curve_type != 3) { // named_curve
404 return false;
405 }
406
407 uint32_t named_curve;
408 if (!parser.Read(&named_curve, 2)) {
409 return false;
410 }
411
412 uint32_t point_length;
413 if (!parser.Read(&point_length, 1)) {
414 return false;
415 }
416
417 UniquePtr<uint8_t[]> key(new uint8_t[point_length]);
418 if (!parser.Read(key.get(), point_length)) {
419 return false;
420 }
421 public_key_.Take(std::move(key), point_length);
422
423 return true;
424 }
425
426 MediaPacket public_key_;
427 };
428
429 namespace {
430 class TransportTestPeer : public sigslot::has_slots<> {
431 public:
TransportTestPeer(nsCOMPtr<nsIEventTarget> target,std::string name,MtransportTestUtils * utils)432 TransportTestPeer(nsCOMPtr<nsIEventTarget> target, std::string name,
433 MtransportTestUtils* utils)
434 : name_(name),
435 offerer_(name == "P1"),
436 target_(target),
437 received_packets_(0),
438 received_bytes_(0),
439 flow_(new TransportFlow(name)),
440 loopback_(new TransportLayerLoopback()),
441 logging_(new TransportLayerLogging()),
442 lossy_(new TransportLayerLossy()),
443 dtls_(new TransportLayerDtls()),
444 identity_(DtlsIdentity::Generate()),
445 ice_ctx_(),
446 streams_(),
447 peer_(nullptr),
448 gathering_complete_(false),
449 digest_("sha-1"),
450 enabled_cipersuites_(),
451 disabled_cipersuites_(),
452 test_utils_(utils) {
453 NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig());
454 ice_ctx_ = NrIceCtx::Create(name, NrIceCtx::Config());
455 std::vector<NrIceStunServer> stun_servers;
456 UniquePtr<NrIceStunServer> server(NrIceStunServer::Create(
457 std::string((char*)"stun.services.mozilla.com"), 3478));
458 stun_servers.push_back(*server);
459 EXPECT_TRUE(NS_SUCCEEDED(ice_ctx_->SetStunServers(stun_servers)));
460
461 dtls_->SetIdentity(identity_);
462 dtls_->SetRole(offerer_ ? TransportLayerDtls::SERVER
463 : TransportLayerDtls::CLIENT);
464
465 nsresult res = identity_->ComputeFingerprint(&digest_);
466 EXPECT_TRUE(NS_SUCCEEDED(res));
467 EXPECT_EQ(20u, digest_.value_.size());
468 }
469
~TransportTestPeer()470 ~TransportTestPeer() {
471 test_utils_->sts_target()->Dispatch(
472 WrapRunnable(this, &TransportTestPeer::DestroyFlow), NS_DISPATCH_SYNC);
473 }
474
DestroyFlow()475 void DestroyFlow() {
476 disconnect_all();
477 if (flow_) {
478 loopback_->Disconnect();
479 flow_ = nullptr;
480 }
481 ice_ctx_->Destroy();
482 ice_ctx_ = nullptr;
483 streams_.clear();
484 }
485
DisconnectDestroyFlow()486 void DisconnectDestroyFlow() {
487 test_utils_->sts_target()->Dispatch(
488 NS_NewRunnableFunction(
489 __func__,
490 [this] {
491 loopback_->Disconnect();
492 disconnect_all(); // Disconnect from the signals;
493 flow_ = nullptr;
494 }),
495 NS_DISPATCH_SYNC);
496 }
497
SetDtlsAllowAll()498 void SetDtlsAllowAll() {
499 nsresult res = dtls_->SetVerificationAllowAll();
500 ASSERT_TRUE(NS_SUCCEEDED(res));
501 }
502
SetAlpn(std::string str,bool withDefault,std::string extra="")503 void SetAlpn(std::string str, bool withDefault, std::string extra = "") {
504 std::set<std::string> alpn;
505 alpn.insert(str); // the one we want to select
506 if (!extra.empty()) {
507 alpn.insert(extra);
508 }
509 nsresult res = dtls_->SetAlpn(alpn, withDefault ? str : "");
510 ASSERT_EQ(NS_OK, res);
511 }
512
GetAlpn() const513 const std::string& GetAlpn() const { return dtls_->GetNegotiatedAlpn(); }
514
SetDtlsPeer(TransportTestPeer * peer,int digests,unsigned int damage)515 void SetDtlsPeer(TransportTestPeer* peer, int digests, unsigned int damage) {
516 unsigned int mask = 1;
517
518 for (int i = 0; i < digests; i++) {
519 DtlsDigest digest_to_set(peer->digest_);
520
521 if (damage & mask) digest_to_set.value_.data()[0]++;
522
523 nsresult res = dtls_->SetVerificationDigest(digest_to_set);
524
525 ASSERT_TRUE(NS_SUCCEEDED(res));
526
527 mask <<= 1;
528 }
529 }
530
SetupSrtp()531 void SetupSrtp() {
532 std::vector<uint16_t> srtp_ciphers =
533 TransportLayerDtls::GetDefaultSrtpCiphers();
534 SetSrtpCiphers(srtp_ciphers);
535 }
536
SetSrtpCiphers(std::vector<uint16_t> & srtp_ciphers)537 void SetSrtpCiphers(std::vector<uint16_t>& srtp_ciphers) {
538 ASSERT_TRUE(NS_SUCCEEDED(dtls_->SetSrtpCiphers(srtp_ciphers)));
539 }
540
ConnectSocket_s(TransportTestPeer * peer)541 void ConnectSocket_s(TransportTestPeer* peer) {
542 nsresult res;
543 res = loopback_->Init();
544 ASSERT_EQ((nsresult)NS_OK, res);
545
546 loopback_->Connect(peer->loopback_);
547 ASSERT_EQ((nsresult)NS_OK, loopback_->Init());
548 ASSERT_EQ((nsresult)NS_OK, logging_->Init());
549 ASSERT_EQ((nsresult)NS_OK, lossy_->Init());
550 ASSERT_EQ((nsresult)NS_OK, dtls_->Init());
551 dtls_->Chain(lossy_);
552 lossy_->Chain(logging_);
553 logging_->Chain(loopback_);
554
555 flow_->PushLayer(loopback_);
556 flow_->PushLayer(logging_);
557 flow_->PushLayer(lossy_);
558 flow_->PushLayer(dtls_);
559
560 if (dtls_->state() != TransportLayer::TS_ERROR) {
561 // Don't execute these blocks if DTLS didn't initialize.
562 TweakCiphers(dtls_->internal_fd());
563 if (post_setup_) {
564 post_setup_(dtls_->internal_fd());
565 }
566 }
567
568 dtls_->SignalPacketReceived.connect(this,
569 &TransportTestPeer::PacketReceived);
570 }
571
TweakCiphers(PRFileDesc * fd)572 void TweakCiphers(PRFileDesc* fd) {
573 for (unsigned short& enabled_cipersuite : enabled_cipersuites_) {
574 SSL_CipherPrefSet(fd, enabled_cipersuite, PR_TRUE);
575 }
576 for (unsigned short& disabled_cipersuite : disabled_cipersuites_) {
577 SSL_CipherPrefSet(fd, disabled_cipersuite, PR_FALSE);
578 }
579 }
580
ConnectSocket(TransportTestPeer * peer)581 void ConnectSocket(TransportTestPeer* peer) {
582 RUN_ON_THREAD(test_utils_->sts_target(),
583 WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer),
584 NS_DISPATCH_SYNC);
585 }
586
InitIce_s()587 nsresult InitIce_s() {
588 nsresult rv = ice_->Init();
589 NS_ENSURE_SUCCESS(rv, rv);
590 rv = dtls_->Init();
591 NS_ENSURE_SUCCESS(rv, rv);
592 dtls_->Chain(ice_);
593 flow_->PushLayer(ice_);
594 flow_->PushLayer(dtls_);
595 return NS_OK;
596 }
597
InitIce()598 void InitIce() {
599 nsresult res;
600
601 // Attach our slots
602 ice_ctx_->SignalGatheringStateChange.connect(
603 this, &TransportTestPeer::GatheringStateChange);
604
605 char name[100];
606 snprintf(name, sizeof(name), "%s:stream%d", name_.c_str(),
607 (int)streams_.size());
608
609 // Create the media stream
610 RefPtr<NrIceMediaStream> stream = ice_ctx_->CreateStream(name, name, 1);
611
612 ASSERT_TRUE(stream != nullptr);
613 stream->SetIceCredentials("ufrag", "pass");
614 streams_.push_back(stream);
615
616 // Listen for candidates
617 stream->SignalCandidate.connect(this, &TransportTestPeer::GotCandidate);
618
619 // Create the transport layer
620 ice_ = new TransportLayerIce();
621 ice_->SetParameters(stream, 1);
622
623 test_utils_->sts_target()->Dispatch(
624 WrapRunnableRet(&res, this, &TransportTestPeer::InitIce_s),
625 NS_DISPATCH_SYNC);
626
627 ASSERT_EQ((nsresult)NS_OK, res);
628
629 // Listen for media events
630 dtls_->SignalPacketReceived.connect(this,
631 &TransportTestPeer::PacketReceived);
632 dtls_->SignalStateChange.connect(this, &TransportTestPeer::StateChanged);
633
634 // Start gathering
635 test_utils_->sts_target()->Dispatch(
636 WrapRunnableRet(&res, ice_ctx_, &NrIceCtx::StartGathering, false,
637 false),
638 NS_DISPATCH_SYNC);
639 ASSERT_TRUE(NS_SUCCEEDED(res));
640 }
641
ConnectIce(TransportTestPeer * peer)642 void ConnectIce(TransportTestPeer* peer) {
643 peer_ = peer;
644
645 // If gathering is already complete, push the candidates over
646 if (gathering_complete_) GatheringComplete();
647 }
648
649 // New candidate
GotCandidate(NrIceMediaStream * stream,const std::string & candidate,const std::string & ufrag,const std::string & mdns_addr,const std::string & actual_addr)650 void GotCandidate(NrIceMediaStream* stream, const std::string& candidate,
651 const std::string& ufrag, const std::string& mdns_addr,
652 const std::string& actual_addr) {
653 std::cerr << "Got candidate " << candidate << " (ufrag=" << ufrag << ")"
654 << std::endl;
655 }
656
GatheringStateChange(NrIceCtx * ctx,NrIceCtx::GatheringState state)657 void GatheringStateChange(NrIceCtx* ctx, NrIceCtx::GatheringState state) {
658 (void)ctx;
659 if (state == NrIceCtx::ICE_CTX_GATHER_COMPLETE) {
660 GatheringComplete();
661 }
662 }
663
664 // Gathering complete, so send our candidates and start
665 // connecting on the other peer.
GatheringComplete()666 void GatheringComplete() {
667 nsresult res;
668
669 // Don't send to the other side
670 if (!peer_) {
671 gathering_complete_ = true;
672 return;
673 }
674
675 // First send attributes
676 test_utils_->sts_target()->Dispatch(
677 WrapRunnableRet(&res, peer_->ice_ctx_, &NrIceCtx::ParseGlobalAttributes,
678 ice_ctx_->GetGlobalAttributes()),
679 NS_DISPATCH_SYNC);
680 ASSERT_TRUE(NS_SUCCEEDED(res));
681
682 for (size_t i = 0; i < streams_.size(); ++i) {
683 test_utils_->sts_target()->Dispatch(
684 WrapRunnableRet(&res, peer_->streams_[i],
685 &NrIceMediaStream::ConnectToPeer, "ufrag", "pass",
686 streams_[i]->GetAttributes()),
687 NS_DISPATCH_SYNC);
688
689 ASSERT_TRUE(NS_SUCCEEDED(res));
690 }
691
692 // Start checks on the other peer.
693 test_utils_->sts_target()->Dispatch(
694 WrapRunnableRet(&res, peer_->ice_ctx_, &NrIceCtx::StartChecks),
695 NS_DISPATCH_SYNC);
696 ASSERT_TRUE(NS_SUCCEEDED(res));
697 }
698
699 // WrapRunnable/lambda and move semantics (MediaPacket is not copyable) don't
700 // get along yet, so we need a wrapper. Gross.
SendPacketWrapper(TransportLayer * layer,MediaPacket * packet)701 static TransportResult SendPacketWrapper(TransportLayer* layer,
702 MediaPacket* packet) {
703 return layer->SendPacket(*packet);
704 }
705
SendPacket(MediaPacket & packet)706 TransportResult SendPacket(MediaPacket& packet) {
707 TransportResult ret;
708
709 test_utils_->sts_target()->Dispatch(
710 WrapRunnableNMRet(&ret, &TransportTestPeer::SendPacketWrapper, dtls_,
711 &packet),
712 NS_DISPATCH_SYNC);
713
714 return ret;
715 }
716
StateChanged(TransportLayer * layer,TransportLayer::State state)717 void StateChanged(TransportLayer* layer, TransportLayer::State state) {
718 if (state == TransportLayer::TS_OPEN) {
719 std::cerr << "Now connected" << std::endl;
720 }
721 }
722
PacketReceived(TransportLayer * layer,MediaPacket & packet)723 void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
724 std::cerr << "Received " << packet.len() << " bytes" << std::endl;
725 ++received_packets_;
726 received_bytes_ += packet.len();
727 }
728
SetLoss(uint32_t loss)729 void SetLoss(uint32_t loss) { lossy_->SetLoss(loss); }
730
SetCombinePackets(bool combine)731 void SetCombinePackets(bool combine) { loopback_->CombinePackets(combine); }
732
SetInspector(UniquePtr<Inspector> inspector)733 void SetInspector(UniquePtr<Inspector> inspector) {
734 lossy_->SetInspector(std::move(inspector));
735 }
736
SetInspector(Inspector * in)737 void SetInspector(Inspector* in) {
738 UniquePtr<Inspector> inspector(in);
739
740 lossy_->SetInspector(std::move(inspector));
741 }
742
SetCipherSuiteChanges(const std::vector<uint16_t> & enableThese,const std::vector<uint16_t> & disableThese)743 void SetCipherSuiteChanges(const std::vector<uint16_t>& enableThese,
744 const std::vector<uint16_t>& disableThese) {
745 disabled_cipersuites_ = disableThese;
746 enabled_cipersuites_ = enableThese;
747 }
748
SetPostSetup(const std::function<void (PRFileDesc *)> & setup)749 void SetPostSetup(const std::function<void(PRFileDesc*)>& setup) {
750 post_setup_ = std::move(setup);
751 }
752
state()753 TransportLayer::State state() {
754 TransportLayer::State tstate;
755
756 RUN_ON_THREAD(test_utils_->sts_target(),
757 WrapRunnableRet(&tstate, dtls_, &TransportLayer::state));
758
759 return tstate;
760 }
761
connected()762 bool connected() { return state() == TransportLayer::TS_OPEN; }
763
failed()764 bool failed() { return state() == TransportLayer::TS_ERROR; }
765
receivedPackets()766 size_t receivedPackets() { return received_packets_; }
767
receivedBytes()768 size_t receivedBytes() { return received_bytes_; }
769
cipherSuite() const770 uint16_t cipherSuite() const {
771 nsresult rv;
772 uint16_t cipher;
773 RUN_ON_THREAD(
774 test_utils_->sts_target(),
775 WrapRunnableRet(&rv, dtls_, &TransportLayerDtls::GetCipherSuite,
776 &cipher));
777
778 if (NS_FAILED(rv)) {
779 return TLS_NULL_WITH_NULL_NULL; // i.e., not good
780 }
781 return cipher;
782 }
783
srtpCipher() const784 uint16_t srtpCipher() const {
785 nsresult rv;
786 uint16_t cipher;
787 RUN_ON_THREAD(test_utils_->sts_target(),
788 WrapRunnableRet(&rv, dtls_,
789 &TransportLayerDtls::GetSrtpCipher, &cipher));
790 if (NS_FAILED(rv)) {
791 return 0; // the SRTP equivalent of TLS_NULL_WITH_NULL_NULL
792 }
793 return cipher;
794 }
795
796 private:
797 std::string name_;
798 bool offerer_;
799 nsCOMPtr<nsIEventTarget> target_;
800 size_t received_packets_;
801 size_t received_bytes_;
802 RefPtr<TransportFlow> flow_;
803 TransportLayerLoopback* loopback_;
804 TransportLayerLogging* logging_;
805 TransportLayerLossy* lossy_;
806 TransportLayerDtls* dtls_;
807 TransportLayerIce* ice_;
808 RefPtr<DtlsIdentity> identity_;
809 RefPtr<NrIceCtx> ice_ctx_;
810 std::vector<RefPtr<NrIceMediaStream> > streams_;
811 TransportTestPeer* peer_;
812 bool gathering_complete_;
813 DtlsDigest digest_;
814 std::vector<uint16_t> enabled_cipersuites_;
815 std::vector<uint16_t> disabled_cipersuites_;
816 MtransportTestUtils* test_utils_;
817 std::function<void(PRFileDesc* fd)> post_setup_ = nullptr;
818 };
819
820 class TransportTest : public MtransportTest {
821 public:
TransportTest()822 TransportTest() {
823 fds_[0] = nullptr;
824 fds_[1] = nullptr;
825 p1_ = nullptr;
826 p2_ = nullptr;
827 }
828
TearDown()829 void TearDown() override {
830 delete p1_;
831 delete p2_;
832
833 // Can't detach these
834 // PR_Close(fds_[0]);
835 // PR_Close(fds_[1]);
836 MtransportTest::TearDown();
837 }
838
DestroyPeerFlows()839 void DestroyPeerFlows() {
840 p1_->DisconnectDestroyFlow();
841 p2_->DisconnectDestroyFlow();
842 }
843
SetUp()844 void SetUp() override {
845 MtransportTest::SetUp();
846
847 nsresult rv;
848 target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
849 ASSERT_TRUE(NS_SUCCEEDED(rv));
850
851 Reset();
852 }
853
Reset()854 void Reset() {
855 if (p1_) {
856 delete p1_;
857 }
858 if (p2_) {
859 delete p2_;
860 }
861 p1_ = new TransportTestPeer(target_, "P1", test_utils_);
862 p2_ = new TransportTestPeer(target_, "P2", test_utils_);
863 }
864
SetupSrtp()865 void SetupSrtp() {
866 p1_->SetupSrtp();
867 p2_->SetupSrtp();
868 }
869
SetDtlsPeer(int digests=1,unsigned int damage=0)870 void SetDtlsPeer(int digests = 1, unsigned int damage = 0) {
871 p1_->SetDtlsPeer(p2_, digests, damage);
872 p2_->SetDtlsPeer(p1_, digests, damage);
873 }
874
SetDtlsAllowAll()875 void SetDtlsAllowAll() {
876 p1_->SetDtlsAllowAll();
877 p2_->SetDtlsAllowAll();
878 }
879
SetAlpn(std::string first,std::string second,bool withDefaults=true)880 void SetAlpn(std::string first, std::string second,
881 bool withDefaults = true) {
882 if (!first.empty()) {
883 p1_->SetAlpn(first, withDefaults, "bogus");
884 }
885 if (!second.empty()) {
886 p2_->SetAlpn(second, withDefaults);
887 }
888 }
889
CheckAlpn(std::string first,std::string second)890 void CheckAlpn(std::string first, std::string second) {
891 ASSERT_EQ(first, p1_->GetAlpn());
892 ASSERT_EQ(second, p2_->GetAlpn());
893 }
894
ConnectSocket()895 void ConnectSocket() {
896 ConnectSocketInternal();
897 ASSERT_TRUE_WAIT(p1_->connected(), 10000);
898 ASSERT_TRUE_WAIT(p2_->connected(), 10000);
899
900 ASSERT_EQ(p1_->cipherSuite(), p2_->cipherSuite());
901 ASSERT_EQ(p1_->srtpCipher(), p2_->srtpCipher());
902 }
903
ConnectSocketExpectFail()904 void ConnectSocketExpectFail() {
905 ConnectSocketInternal();
906 ASSERT_TRUE_WAIT(p1_->failed(), 10000);
907 ASSERT_TRUE_WAIT(p2_->failed(), 10000);
908 }
909
ConnectSocketExpectState(TransportLayer::State s1,TransportLayer::State s2)910 void ConnectSocketExpectState(TransportLayer::State s1,
911 TransportLayer::State s2) {
912 ConnectSocketInternal();
913 ASSERT_EQ_WAIT(s1, p1_->state(), 10000);
914 ASSERT_EQ_WAIT(s2, p2_->state(), 10000);
915 }
916
ConnectIce()917 void ConnectIce() {
918 p1_->InitIce();
919 p2_->InitIce();
920 p1_->ConnectIce(p2_);
921 p2_->ConnectIce(p1_);
922 ASSERT_TRUE_WAIT(p1_->connected(), 10000);
923 ASSERT_TRUE_WAIT(p2_->connected(), 10000);
924 }
925
TransferTest(size_t count,size_t bytes=1024)926 void TransferTest(size_t count, size_t bytes = 1024) {
927 unsigned char buf[bytes];
928
929 for (size_t i = 0; i < count; ++i) {
930 memset(buf, count & 0xff, sizeof(buf));
931 MediaPacket packet;
932 packet.Copy(buf, sizeof(buf));
933 TransportResult rv = p1_->SendPacket(packet);
934 ASSERT_TRUE(rv > 0);
935 }
936
937 std::cerr << "Received == " << p2_->receivedPackets() << " packets"
938 << std::endl;
939 ASSERT_TRUE_WAIT(count == p2_->receivedPackets(), 10000);
940 ASSERT_TRUE((count * sizeof(buf)) == p2_->receivedBytes());
941 }
942
943 protected:
ConnectSocketInternal()944 void ConnectSocketInternal() {
945 test_utils_->sts_target()->Dispatch(
946 WrapRunnable(p1_, &TransportTestPeer::ConnectSocket, p2_),
947 NS_DISPATCH_SYNC);
948 test_utils_->sts_target()->Dispatch(
949 WrapRunnable(p2_, &TransportTestPeer::ConnectSocket, p1_),
950 NS_DISPATCH_SYNC);
951 }
952
953 PRFileDesc* fds_[2];
954 TransportTestPeer* p1_;
955 TransportTestPeer* p2_;
956 nsCOMPtr<nsIEventTarget> target_;
957 };
958
// Neither peer configures a fingerprint nor allow-all, so both sides are
// expected to end up in TS_ERROR.
TEST_F(TransportTest, TestNoDtlsVerificationSettings) {
  this->ConnectSocketExpectFail();
}
962
DisableChaCha(TransportTestPeer * peer)963 static void DisableChaCha(TransportTestPeer* peer) {
964 // On ARM, ChaCha20Poly1305 might be preferred; disable it for the tests that
965 // want to check the cipher suite. It doesn't matter which peer disables the
966 // suite, disabling on either side has the same effect.
967 std::vector<uint16_t> chachaSuites;
968 chachaSuites.push_back(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256);
969 chachaSuites.push_back(TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256);
970 peer->SetCipherSuiteChanges(std::vector<uint16_t>(), chachaSuites);
971 }
972
// Plain DTLS over loopback with verified fingerprints and no SRTP.
TEST_F(TransportTest, TestConnect) {
  SetDtlsPeer();
  DisableChaCha(p1_);
  ConnectSocket();

  // With ChaCha disabled, the ECDSA/AES-128-GCM suite must win.
  const auto negotiated = p1_->cipherSuite();
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, negotiated);

  // SRTP was never configured, so no SRTP cipher is negotiated.
  ASSERT_EQ(0, p1_->srtpCipher());
}
984
// DTLS-SRTP with the default SRTP cipher list on both peers.
TEST_F(TransportTest, TestConnectSrtp) {
  SetupSrtp();
  SetDtlsPeer();
  DisableChaCha(p2_);
  ConnectSocket();

  const auto negotiated = p1_->cipherSuite();
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, negotiated);

  // The default SRTP list must select AEAD AES-128-GCM.
  ASSERT_EQ(kDtlsSrtpAeadAes128Gcm, p1_->srtpCipher());
}
996
// After a successful handshake, tears the flows down explicitly, before the
// peers themselves are destroyed.
TEST_F(TransportTest, TestConnectDestroyFlowsMainThread) {
  this->SetDtlsPeer();
  this->ConnectSocket();
  this->DestroyPeerFlows();
}
1002
// Both peers skip fingerprint verification entirely.
TEST_F(TransportTest, TestConnectAllowAll) {
  this->SetDtlsAllowAll();
  this->ConnectSocket();
}
1007
// Both peers offer the same protocol, which must be negotiated on each side.
TEST_F(TransportTest, TestConnectAlpn) {
  SetDtlsPeer();
  const std::string protocol = "a";
  SetAlpn(protocol, protocol);
  ConnectSocket();
  CheckAlpn(protocol, protocol);
}
1014
// Server and client offer disjoint ALPN protocols; the handshake must fail.
TEST_F(TransportTest, TestConnectAlpnMismatch) {
  SetDtlsPeer();
  const std::string serverProtocol = "something";
  const std::string clientProtocol = "different";
  SetAlpn(serverProtocol, clientProtocol);
  ConnectSocketExpectFail();
}
1020
// The server (p1) offers "def" and allows it as the default; the client
// (p2) has no ALPN configured at all.
TEST_F(TransportTest, TestConnectAlpnServerDefault) {
  SetDtlsPeer();
  const std::string serverDefault = "def";
  SetAlpn(serverDefault, "");
  ConnectSocket();
  CheckAlpn(serverDefault, "");
}
1028
// Only the client (p2) configures ALPN, with a default. The server ignores
// the extension, so the client falls back to its default.
TEST_F(TransportTest, TestConnectAlpnClientDefault) {
  SetDtlsPeer();
  const std::string clientDefault = "clientdef";
  SetAlpn("", clientDefault);
  ConnectSocket();
  CheckAlpn("", clientDefault);
}
1036
// The server (p1) requires ALPN without a default; the client (p2) sends no
// ALPN extension. The server negotiates without it, then has to close when
// it finds ALPN was not negotiated; the client just observes the close.
TEST_F(TransportTest, TestConnectClientNoAlpn) {
  SetDtlsPeer();
  const bool allowDefault = false;
  SetAlpn("server-nodefault", "", allowDefault);
  const auto serverState = TransportLayer::TS_ERROR;
  const auto clientState = TransportLayer::TS_CLOSED;
  ConnectSocketExpectState(serverState, clientState);
}
1047
// The client (p2) requires ALPN without a default; the server (p1) has none.
// The client aborts, while the server only observes the close.
TEST_F(TransportTest, TestConnectServerNoAlpn) {
  SetDtlsPeer();
  const bool allowDefault = false;
  SetAlpn("", "client-nodefault", allowDefault);
  const auto serverState = TransportLayer::TS_CLOSED;
  const auto clientState = TransportLayer::TS_ERROR;
  ConnectSocketExpectState(serverState, clientState);
}
1055
// SetDtlsPeer(0, 0): judging by the test names, the arguments are
// (digest count, mask of corrupted digests) -- here no digest at all, which
// must make the connection fail.
TEST_F(TransportTest, TestConnectNoDigest) {
  this->SetDtlsPeer(0, 0);
  this->ConnectSocketExpectFail();
}
1061
// SetDtlsPeer(1, 1): presumably one digest, and that digest is corrupted
// (mask bit 0), so the connection must fail.
TEST_F(TransportTest, TestConnectBadDigest) {
  this->SetDtlsPeer(1, 1);
  this->ConnectSocketExpectFail();
}
1067
// SetDtlsPeer(2, 0): presumably two digests, neither corrupted; the
// connection must succeed.
TEST_F(TransportTest, TestConnectTwoDigests) {
  this->SetDtlsPeer(2, 0);
  this->ConnectSocket();
}
1073
// SetDtlsPeer(2, 1): presumably two digests with the first one corrupted;
// the remaining good digest is enough to connect.
TEST_F(TransportTest, TestConnectTwoDigestsFirstBad) {
  this->SetDtlsPeer(2, 1);
  this->ConnectSocket();
}
1079
// SetDtlsPeer(2, 2): presumably two digests with the second one corrupted;
// the remaining good digest is enough to connect.
TEST_F(TransportTest, TestConnectTwoDigestsSecondBad) {
  this->SetDtlsPeer(2, 2);
  this->ConnectSocket();
}
1085
// SetDtlsPeer(2, 3): presumably two digests, both corrupted (mask 0b11), so
// the connection must fail.
TEST_F(TransportTest, TestConnectTwoDigestsBothBad) {
  this->SetDtlsPeer(2, 3);
  this->ConnectSocketExpectFail();
}
1091
// Give the client (p2_) a DtlsInspectorInjector keyed on the handshake
// record type / Certificate message, carrying the fake ChangeCipherSpec
// record kTlsFakeChangeCipherSpec (see DtlsInspectorInjector for exactly
// where it is injected). The connection must still succeed.
TEST_F(TransportTest, TestConnectInjectCCS) {
  SetDtlsPeer();
  const auto kFakeCcsLen = sizeof(kTlsFakeChangeCipherSpec);
  p2_->SetInspector(MakeUnique<DtlsInspectorInjector>(
      kTlsHandshakeType, kTlsHandshakeCertificate, kTlsFakeChangeCipherSpec,
      kFakeCcsLen));

  ConnectSocket();
}
1100
// Run two separate handshakes and record the server's (p1_)
// ServerKeyExchange each time. Without SSL_REUSE_SERVER_ECDHE_KEY, the
// ECDHE public keys of the two connections must not be identical.
TEST_F(TransportTest, TestConnectVerifyNewECDHE) {
  // Handshake #1.
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* firstRecorder =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(firstRecorder);
  ConnectSocket();
  TlsServerKeyExchangeECDHE firstKx;
  ASSERT_TRUE(firstKx.Parse(firstRecorder->buffer().data(),
                            firstRecorder->buffer().len()));

  // Handshake #2, after resetting the fixture.
  Reset();
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* secondRecorder =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(secondRecorder);
  ConnectSocket();
  TlsServerKeyExchangeECDHE secondKx;
  ASSERT_TRUE(secondKx.Parse(secondRecorder->buffer().data(),
                             secondRecorder->buffer().len()));

  // The keys count as the same only when both the lengths and the bytes
  // match (the memcmp runs only if the lengths agree).
  const bool sameKey =
      (firstKx.public_key_.len() == secondKx.public_key_.len()) &&
      (!memcmp(firstKx.public_key_.data(), secondKx.public_key_.data(),
               firstKx.public_key_.len()));
  ASSERT_FALSE(sameKey);
}
1124
// With SSL_REUSE_SERVER_ECDHE_KEY switched back on for the server (p1_), two
// consecutive handshakes must observe the exact same ECDHE public key.
TEST_F(TransportTest, TestConnectVerifyReusedECDHE) {
  // TransportLayerDtls automatically sets this option to false, so turn it
  // back on for the test. Digging directly into the NSS fd is pretty gross,
  // but TransportLayerDtls doesn't expose the feature being tested.
  auto enableKeyReuse = [](PRFileDesc* fd) {
    SECStatus rv = SSL_OptionSet(fd, SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
    ASSERT_EQ(SECSuccess, rv);
  };

  // Handshake #1.
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* firstRecorder =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(firstRecorder);
  p1_->SetPostSetup(enableKeyReuse);
  ConnectSocket();
  TlsServerKeyExchangeECDHE firstKx;
  ASSERT_TRUE(firstKx.Parse(firstRecorder->buffer().data(),
                            firstRecorder->buffer().len()));

  // Handshake #2, after resetting the fixture.
  Reset();
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* secondRecorder =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(secondRecorder);
  p1_->SetPostSetup(enableKeyReuse);
  ConnectSocket();
  TlsServerKeyExchangeECDHE secondKx;
  ASSERT_TRUE(secondKx.Parse(secondRecorder->buffer().data(),
                             secondRecorder->buffer().len()));

  // The reused key must match byte for byte.
  ASSERT_EQ(firstKx.public_key_.len(), secondKx.public_key_.len());
  ASSERT_TRUE(!memcmp(firstKx.public_key_.data(), secondKx.public_key_.data(),
                      firstKx.public_key_.len()));
}
1162
// Connect, then run TransferTest with a count of 1.
TEST_F(TransportTest, TestTransfer) {
  this->SetDtlsPeer();
  this->ConnectSocket();
  this->TransferTest(1);
}
1168
// Send one packet as large as TransportLayerDtls's buffer.
TEST_F(TransportTest, TestTransferMaxSize) {
  SetDtlsPeer();
  ConnectSocket();
  // TransportLayerDtls uses a 9216-byte buffer. This test runs over the
  // loopback implementation, so the extra bytes added by the DTLS layer below
  // don't have to be taken into account.
  const int kDtlsBufferSize = 9216;
  TransferTest(1, kDtlsBufferSize);
}
1177
// Connect, then run TransferTest with a count of 3.
TEST_F(TransportTest, TestTransferMultiple) {
  this->SetDtlsPeer();
  this->ConnectSocket();
  this->TransferTest(3);
}
1183
// Like TestTransferMultiple, but with SetCombinePackets(true) enabled on the
// client (p2_) after the connection is up.
TEST_F(TransportTest, TestTransferCombinedPackets) {
  this->SetDtlsPeer();
  this->ConnectSocket();
  this->p2_->SetCombinePackets(true);
  this->TransferTest(3);
}
1190
// Configure SetLoss(0) on the server (p1_) before connecting; presumably this
// drops its first packet. Connecting and a single transfer must still work.
TEST_F(TransportTest, TestConnectLoseFirst) {
  this->SetDtlsPeer();
  this->p1_->SetLoss(0);
  this->ConnectSocket();
  this->TransferTest(1);
}
1197
// Establish the DTLS connection over ICE rather than via ConnectSocket().
TEST_F(TransportTest, TestConnectIce) {
  this->SetDtlsPeer();
  this->ConnectIce();
}
1202
// Send one packet of the largest size that reliably survives the ICE path.
TEST_F(TransportTest, TestTransferIceMaxSize) {
  SetDtlsPeer();
  ConnectIce();
  // nICEr and TransportLayerDtls both use 9216-byte buffers, but the DTLS
  // layer adds extra bytes to each packet, and how many depends on the chosen
  // cipher etc. Sending more than 9216 bytes works, but on the receiving
  // side PR_recvfrom() truncates any packet bigger than nICEr's 9216-byte
  // buffer, and the DTLS layer then discards the packet. We therefore leave
  // headroom for the DTLS bytes. According to
  // https://bugzilla.mozilla.org/show_bug.cgi?id=1214269#c29, 256 bytes
  // should be a safe choice.
  const int kNicerBufferSize = 9216;
  const int kDtlsHeadroom = 256;
  TransferTest(1, kNicerBufferSize - kDtlsHeadroom);
}
1217
// ICE variant of TestTransferMultiple.
TEST_F(TransportTest, TestTransferIceMultiple) {
  this->SetDtlsPeer();
  this->ConnectIce();
  this->TransferTest(3);
}
1223
// ICE variant of TestTransferCombinedPackets: the client (p2_) has
// SetCombinePackets(true) enabled before the transfer.
TEST_F(TransportTest, TestTransferIceCombinedPackets) {
  this->SetDtlsPeer();
  this->ConnectIce();
  this->p2_->SetCombinePackets(true);
  this->TransferTest(3);
}
1230
1231 // test the default configuration against a peer that supports only
1232 // one of the mandatory-to-implement suites, which should succeed
ConfigureOneCipher(TransportTestPeer * peer,uint16_t suite)1233 static void ConfigureOneCipher(TransportTestPeer* peer, uint16_t suite) {
1234 std::vector<uint16_t> justOne;
1235 justOne.push_back(suite);
1236 std::vector<uint16_t> everythingElse(
1237 SSL_GetImplementedCiphers(),
1238 SSL_GetImplementedCiphers() + SSL_GetNumImplementedCiphers());
1239 std::remove(everythingElse.begin(), everythingElse.end(), suite);
1240 peer->SetCipherSuiteChanges(justOne, everythingElse);
1241 }
1242
// The server (p1_) is limited to AES-128-GCM and the client (p2_) to
// AES-128-CBC. They share no suite, so the handshake must fail.
TEST_F(TransportTest, TestCipherMismatch) {
  this->SetDtlsPeer();
  ConfigureOneCipher(this->p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256);
  ConfigureOneCipher(this->p2_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA);
  this->ConnectSocketExpectFail();
}
1249
// Default client against a server (p1_) that supports only the
// mandatory-to-implement AES-128-GCM suite. This must succeed and select it.
TEST_F(TransportTest, TestCipherMandatoryOnlyGcm) {
  this->SetDtlsPeer();
  ConfigureOneCipher(this->p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256);
  this->ConnectSocket();
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
}
1256
// Default client against a server (p1_) that supports only the
// mandatory-to-implement AES-128-CBC suite. This must succeed and select it.
TEST_F(TransportTest, TestCipherMandatoryOnlyCbc) {
  this->SetDtlsPeer();
  ConfigureOneCipher(this->p1_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA);
  this->ConnectSocket();
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, p1_->cipherSuite());
}
1263
// The server (p1_) offers only AES128_CM_HMAC_SHA1_80 and the client (p2_)
// only AES128_CM_HMAC_SHA1_32. The connection must fail and neither side may
// report an SRTP cipher.
TEST_F(TransportTest, TestSrtpMismatch) {
  std::vector<uint16_t> serverCiphers(1, kDtlsSrtpAes128CmHmacSha1_80);
  std::vector<uint16_t> clientCiphers(1, kDtlsSrtpAes128CmHmacSha1_32);

  p1_->SetSrtpCiphers(serverCiphers);
  p2_->SetSrtpCiphers(clientCiphers);
  SetDtlsPeer();
  ConnectSocketExpectFail();

  ASSERT_EQ(0, p1_->srtpCipher());
  ASSERT_EQ(0, p2_->srtpCipher());
}
1278
NoopXtnHandler(PRFileDesc * fd,SSLHandshakeType message,const uint8_t * data,unsigned int len,SSLAlertDescription * alert,void * arg)1279 static SECStatus NoopXtnHandler(PRFileDesc* fd, SSLHandshakeType message,
1280 const uint8_t* data, unsigned int len,
1281 SSLAlertDescription* alert, void* arg) {
1282 return SECSuccess;
1283 }
1284
WriteFixedXtn(PRFileDesc * fd,SSLHandshakeType message,uint8_t * data,unsigned int * len,unsigned int max_len,void * arg)1285 static PRBool WriteFixedXtn(PRFileDesc* fd, SSLHandshakeType message,
1286 uint8_t* data, unsigned int* len,
1287 unsigned int max_len, void* arg) {
1288 // When we enable TLS 1.3, change ssl_hs_server_hello here to
1289 // ssl_hs_encrypted_extensions. At the same time, add a test that writes to
1290 // ssl_hs_server_hello, which should fail.
1291 if (message != ssl_hs_client_hello && message != ssl_hs_server_hello) {
1292 return false;
1293 }
1294
1295 auto v = reinterpret_cast<std::vector<uint8_t>*>(arg);
1296 memcpy(data, &((*v)[0]), v->size());
1297 *len = v->size();
1298 return true;
1299 }
1300
1301 // Note that |value| needs to be readable after this function returns.
InstallBadSrtpExtensionWriter(TransportTestPeer * peer,std::vector<uint8_t> * value)1302 static void InstallBadSrtpExtensionWriter(TransportTestPeer* peer,
1303 std::vector<uint8_t>* value) {
1304 peer->SetPostSetup([value](PRFileDesc* fd) {
1305 // Override the handler that is installed by the DTLS setup.
1306 SECStatus rv = SSL_InstallExtensionHooks(
1307 fd, ssl_use_srtp_xtn, WriteFixedXtn, value, NoopXtnHandler, nullptr);
1308 ASSERT_EQ(SECSuccess, rv);
1309 });
1310 }
1311
// The server (p1_) answers with a use_srtp extension listing two protection
// profiles (0x0001 and 0x0002) and an empty MKI. A server must select exactly
// one profile (RFC 5764 4.1.1), so the handshake must fail.
TEST_F(TransportTest, TestSrtpErrorServerSendsTwoSrtpCiphers) {
  // Layout: uint16 profile-list length (4), two uint16 profiles, then a
  // uint8 MKI length (0). The leading 0x00 is the high byte of the list
  // length; without it the list would parse as 0x0400 bytes long.
  std::vector<uint8_t> xtn = {0x00, 0x04, 0x00, 0x01, 0x00, 0x02, 0x00};
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1320
// The server (p1_) selects a valid profile but also sends an MKI. The client
// offered no MKI, so the handshake must fail.
TEST_F(TransportTest, TestSrtpErrorServerSendsTwoMki) {
  // Layout (RFC 5764 4.1.1): uint16 list length (2), profile 0x0001, then
  // uint8 MKI length (1) followed by the one-byte MKI 0x00.
  std::vector<uint8_t> xtn = {0x00, 0x02, 0x00, 0x01, 0x01, 0x00};
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1329
// The server (p1_) selects the profile 0x9af1, which the client never
// offered, so the handshake must fail.
TEST_F(TransportTest, TestSrtpErrorServerSendsUnknownValue) {
  // Layout (RFC 5764 4.1.1): uint16 list length (2), profile 0x9af1, then
  // uint8 MKI length (0).
  std::vector<uint8_t> xtn = {0x00, 0x02, 0x9a, 0xf1, 0x00};
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1337
// The server's (p1_) profile list claims 0x32 (50) bytes, but far fewer
// follow. The length overflows the extension, so the handshake must fail.
TEST_F(TransportTest, TestSrtpErrorServerSendsOverflow) {
  // Layout (RFC 5764 4.1.1): uint16 list length (0x0032), then only one
  // profile (0x0001) and a uint8 MKI length (0).
  std::vector<uint8_t> xtn = {0x00, 0x32, 0x00, 0x01, 0x00};
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1345
// The server's (p1_) profile list is 1 byte long, which cannot hold a whole
// 2-byte profile, so the handshake must fail.
TEST_F(TransportTest, TestSrtpErrorServerSendsUnevenList) {
  // Layout (RFC 5764 4.1.1): uint16 list length (1), a single list byte
  // (0x00), then uint8 MKI length (0).
  std::vector<uint8_t> xtn = {0x00, 0x01, 0x00, 0x00};
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1353
// The client's (p2_) offered profile list is 1 byte long, which cannot hold a
// whole 2-byte profile, so the handshake must fail.
TEST_F(TransportTest, TestSrtpErrorClientSendsUnevenList) {
  // Layout (RFC 5764 4.1.1): uint16 list length (1), a single list byte
  // (0x00), then uint8 MKI length (0).
  std::vector<uint8_t> xtn = {0x00, 0x01, 0x00, 0x00};
  InstallBadSrtpExtensionWriter(p2_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1361
// Only the server (p1_) is set up for SRTP.
TEST_F(TransportTest, OnlyServerSendsSrtpXtn) {
  this->p1_->SetupSrtp();
  this->SetDtlsPeer();
  // The connection should succeed, but with no SRTP extension negotiated;
  // the client side might negotiate a data channel only.
  this->ConnectSocket();
  ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite());
  ASSERT_EQ(0, p1_->srtpCipher());
}
1371
// Only the client (p2_) is set up for SRTP.
TEST_F(TransportTest, OnlyClientSendsSrtpXtn) {
  this->p2_->SetupSrtp();
  this->SetDtlsPeer();
  // The connection should succeed, but with no SRTP extension negotiated;
  // the server side might negotiate a data channel only.
  this->ConnectSocket();
  ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite());
  ASSERT_EQ(0, p1_->srtpCipher());
}
1381
1382 class TransportSrtpParameterTest
1383 : public TransportTest,
1384 public ::testing::WithParamInterface<uint16_t> {};
1385
1386 INSTANTIATE_TEST_CASE_P(
1387 SrtpParamInit, TransportSrtpParameterTest,
1388 ::testing::ValuesIn(TransportLayerDtls::GetDefaultSrtpCiphers()));
1389
// The server (p1_) is configured by SetupSrtp(), presumably with the default
// cipher set, while the client (p2_) offers only the parameterized cipher.
// Both sides must end up agreeing on that cipher.
TEST_P(TransportSrtpParameterTest, TestSrtpCiphersMismatchCombinations) {
  const uint16_t cipher = GetParam();
  std::cerr << "Checking cipher: " << cipher << std::endl;

  p1_->SetupSrtp();

  std::vector<uint16_t> clientCiphers(1, cipher);
  p2_->SetSrtpCiphers(clientCiphers);

  SetDtlsPeer();
  ConnectSocket();

  ASSERT_EQ(cipher, p1_->srtpCipher());
  ASSERT_EQ(cipher, p2_->srtpCipher());
}
1406
1407 // NSS doesn't support DHE suites on the server end.
1408 // This checks to see if we barf when that's the only option available.
TEST_F(TransportTest,TestDheOnlyFails)1409 TEST_F(TransportTest, TestDheOnlyFails) {
1410 SetDtlsPeer();
1411
1412 // p2_ is the client
1413 // setting this on p1_ (the server) causes NSS to assert
1414 ConfigureOneCipher(p2_, TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
1415 ConnectSocketExpectFail();
1416 }
1417
1418 } // end namespace
1419