1 
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 // Original author: ekr@rtfm.com
8 
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>
14 
15 #ifdef XP_MACOSX
16 // ensure that Apple Security kit enum goes before "sslproto.h"
17 #  include <CoreFoundation/CFAvailability.h>
18 #  include <Security/CipherSuite.h>
19 #endif
20 
21 #include "mozilla/UniquePtr.h"
22 
23 #include "sigslot.h"
24 
25 #include "logging.h"
26 #include "ssl.h"
27 #include "sslexp.h"
28 #include "sslproto.h"
29 
30 #include "nsThreadUtils.h"
31 #include "nsXPCOM.h"
32 
33 #include "mediapacket.h"
34 #include "dtlsidentity.h"
35 #include "nricectx.h"
36 #include "nricemediastream.h"
37 #include "transportflow.h"
38 #include "transportlayer.h"
39 #include "transportlayerdtls.h"
40 #include "transportlayerice.h"
41 #include "transportlayerlog.h"
42 #include "transportlayerloopback.h"
43 
44 #include "runnable_utils.h"
45 
46 #define GTEST_HAS_RTTI 0
47 #include "gtest/gtest.h"
48 #include "gtest_utils.h"
49 
50 using namespace mozilla;
51 MOZ_MTLOG_MODULE("mtransport")
52 
// TLS record content types.
constexpr uint8_t kTlsChangeCipherSpecType = 0x14;
constexpr uint8_t kTlsHandshakeType = 0x16;

// TLS handshake message types.
constexpr uint8_t kTlsHandshakeCertificate = 0x0b;
constexpr uint8_t kTlsHandshakeServerKeyExchange = 0x0c;

// A complete 14-byte DTLS ChangeCipherSpec record: the 13-byte record header
// followed by a one-byte body.
constexpr uint8_t kTlsFakeChangeCipherSpec[] = {
    kTlsChangeCipherSpecType,            // Type
    0xfe, 0xff,                          // Version
    0x00, 0x00,                          // Epoch
    0x00, 0x00, 0x00, 0x00, 0x00, 0x10,  // Fictitious sequence #
    0x00, 0x01,                          // Length
    0x01                                 // Value
};
75 
76 // Layer class which can't be initialized.
77 class TransportLayerDummy : public TransportLayer {
78  public:
TransportLayerDummy(bool allow_init,bool * destroyed)79   TransportLayerDummy(bool allow_init, bool* destroyed)
80       : allow_init_(allow_init), destroyed_(destroyed) {
81     *destroyed_ = false;
82   }
83 
~TransportLayerDummy()84   virtual ~TransportLayerDummy() { *destroyed_ = true; }
85 
InitInternal()86   nsresult InitInternal() override {
87     return allow_init_ ? NS_OK : NS_ERROR_FAILURE;
88   }
89 
SendPacket(MediaPacket & packet)90   TransportResult SendPacket(MediaPacket& packet) override {
91     MOZ_CRASH();  // Should never be called.
92     return 0;
93   }
94 
95   TRANSPORT_LAYER_ID("lossy")
96 
97  private:
98   bool allow_init_;
99   bool* destroyed_;
100 };
101 
102 class Inspector {
103  public:
104   virtual ~Inspector() = default;
105 
106   virtual void Inspect(TransportLayer* layer, const unsigned char* data,
107                        size_t len) = 0;
108 };
109 
110 // Class to simulate various kinds of network lossage
111 class TransportLayerLossy : public TransportLayer {
112  public:
TransportLayerLossy()113   TransportLayerLossy() : loss_mask_(0), packet_(0), inspector_(nullptr) {}
114   ~TransportLayerLossy() = default;
115 
SendPacket(MediaPacket & packet)116   TransportResult SendPacket(MediaPacket& packet) override {
117     MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "SendPacket(" << packet.len() << ")");
118 
119     if (loss_mask_ & (1 << (packet_ % 32))) {
120       MOZ_MTLOG(ML_NOTICE, "Dropping packet");
121       ++packet_;
122       return packet.len();
123     }
124     if (inspector_) {
125       inspector_->Inspect(this, packet.data(), packet.len());
126     }
127 
128     ++packet_;
129 
130     return downward_->SendPacket(packet);
131   }
132 
SetLoss(uint32_t packet)133   void SetLoss(uint32_t packet) { loss_mask_ |= (1 << (packet & 32)); }
134 
SetInspector(UniquePtr<Inspector> inspector)135   void SetInspector(UniquePtr<Inspector> inspector) {
136     inspector_ = std::move(inspector);
137   }
138 
StateChange(TransportLayer * layer,State state)139   void StateChange(TransportLayer* layer, State state) { TL_SET_STATE(state); }
140 
PacketReceived(TransportLayer * layer,MediaPacket & packet)141   void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
142     SignalPacketReceived(this, packet);
143   }
144 
145   TRANSPORT_LAYER_ID("lossy")
146 
147  protected:
WasInserted()148   void WasInserted() override {
149     downward_->SignalPacketReceived.connect(
150         this, &TransportLayerLossy::PacketReceived);
151     downward_->SignalStateChange.connect(this,
152                                          &TransportLayerLossy::StateChange);
153 
154     TL_SET_STATE(downward_->state());
155   }
156 
157  private:
158   uint32_t loss_mask_;
159   uint32_t packet_;
160   UniquePtr<Inspector> inspector_;
161 };
162 
163 // Process DTLS Records
// For use inside a bool-returning member of a parser class that provides
// remaining().  If fewer than |expected| bytes remain, it records a non-fatal
// gtest failure and returns false from the enclosing function.  |expected| is
// parenthesized so that compound argument expressions are evaluated as a unit.
#define CHECK_LENGTH(expected)                  \
  do {                                          \
    EXPECT_GE(remaining(), (expected));         \
    if (remaining() < (expected)) return false; \
  } while (0)
169 
170 class TlsParser {
171  public:
TlsParser(const unsigned char * data,size_t len)172   TlsParser(const unsigned char* data, size_t len) : buffer_(), offset_(0) {
173     buffer_.Copy(data, len);
174   }
175 
Read(unsigned char * val)176   bool Read(unsigned char* val) {
177     if (remaining() < 1) {
178       return false;
179     }
180     *val = *ptr();
181     consume(1);
182     return true;
183   }
184 
185   // Read an integral type of specified width.
Read(uint32_t * val,size_t len)186   bool Read(uint32_t* val, size_t len) {
187     if (len > sizeof(uint32_t)) return false;
188 
189     *val = 0;
190 
191     for (size_t i = 0; i < len; ++i) {
192       unsigned char tmp;
193 
194       if (!Read(&tmp)) return false;
195 
196       (*val) = ((*val) << 8) + tmp;
197     }
198 
199     return true;
200   }
201 
Read(unsigned char * val,size_t len)202   bool Read(unsigned char* val, size_t len) {
203     if (remaining() < len) {
204       return false;
205     }
206 
207     if (val) {
208       memcpy(val, ptr(), len);
209     }
210     consume(len);
211 
212     return true;
213   }
214 
215  private:
remaining() const216   size_t remaining() const { return buffer_.len() - offset_; }
ptr() const217   const uint8_t* ptr() const { return buffer_.data() + offset_; }
consume(size_t len)218   void consume(size_t len) { offset_ += len; }
219 
220   MediaPacket buffer_;
221   size_t offset_;
222 };
223 
224 class DtlsRecordParser {
225  public:
DtlsRecordParser(const unsigned char * data,size_t len)226   DtlsRecordParser(const unsigned char* data, size_t len)
227       : buffer_(), offset_(0) {
228     buffer_.Copy(data, len);
229   }
230 
NextRecord(uint8_t * ct,UniquePtr<MediaPacket> * buffer)231   bool NextRecord(uint8_t* ct, UniquePtr<MediaPacket>* buffer) {
232     if (!remaining()) return false;
233 
234     CHECK_LENGTH(13U);
235     const uint8_t* ctp = reinterpret_cast<const uint8_t*>(ptr());
236     consume(11);  // ct + version + length
237 
238     const uint16_t* tmp = reinterpret_cast<const uint16_t*>(ptr());
239     size_t length = ntohs(*tmp);
240     consume(2);
241 
242     CHECK_LENGTH(length);
243     auto db = MakeUnique<MediaPacket>();
244     db->Copy(ptr(), length);
245     consume(length);
246 
247     *ct = *ctp;
248     *buffer = std::move(db);
249 
250     return true;
251   }
252 
253  private:
remaining() const254   size_t remaining() const { return buffer_.len() - offset_; }
ptr() const255   const uint8_t* ptr() const { return buffer_.data() + offset_; }
consume(size_t len)256   void consume(size_t len) { offset_ += len; }
257 
258   MediaPacket buffer_;
259   size_t offset_;
260 };
261 
262 // Inspector that parses out DTLS records and passes
263 // them on.
264 class DtlsRecordInspector : public Inspector {
265  public:
Inspect(TransportLayer * layer,const unsigned char * data,size_t len)266   virtual void Inspect(TransportLayer* layer, const unsigned char* data,
267                        size_t len) {
268     DtlsRecordParser parser(data, len);
269 
270     uint8_t ct;
271     UniquePtr<MediaPacket> buf;
272     while (parser.NextRecord(&ct, &buf)) {
273       OnRecord(layer, ct, buf->data(), buf->len());
274     }
275   }
276 
277   virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
278                         const unsigned char* record, size_t len) = 0;
279 };
280 
281 // Inspector that injects arbitrary packets based on
282 // DTLS records of various types.
283 class DtlsInspectorInjector : public DtlsRecordInspector {
284  public:
DtlsInspectorInjector(uint8_t packet_type,uint8_t handshake_type,const unsigned char * data,size_t len)285   DtlsInspectorInjector(uint8_t packet_type, uint8_t handshake_type,
286                         const unsigned char* data, size_t len)
287       : packet_type_(packet_type), handshake_type_(handshake_type) {
288     packet_.Copy(data, len);
289   }
290 
OnRecord(TransportLayer * layer,uint8_t content_type,const unsigned char * data,size_t len)291   virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
292                         const unsigned char* data, size_t len) {
293     // Only inject once.
294     if (!packet_.data()) {
295       return;
296     }
297 
298     // Check that the first byte is as requested.
299     if (content_type != packet_type_) {
300       return;
301     }
302 
303     if (handshake_type_ != 0xff) {
304       // Check that the packet is plausibly long enough.
305       if (len < 1) {
306         return;
307       }
308 
309       // Check that the handshake type is as requested.
310       if (data[0] != handshake_type_) {
311         return;
312       }
313     }
314 
315     layer->SendPacket(packet_);
316     packet_.Reset();
317   }
318 
319  private:
320   uint8_t packet_type_;
321   uint8_t handshake_type_;
322   MediaPacket packet_;
323 };
324 
325 // Make a copy of the first instance of a message.
326 class DtlsInspectorRecordHandshakeMessage : public DtlsRecordInspector {
327  public:
DtlsInspectorRecordHandshakeMessage(uint8_t handshake_type)328   explicit DtlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
329       : handshake_type_(handshake_type), buffer_() {}
330 
OnRecord(TransportLayer * layer,uint8_t content_type,const unsigned char * data,size_t len)331   virtual void OnRecord(TransportLayer* layer, uint8_t content_type,
332                         const unsigned char* data, size_t len) {
333     // Only do this once.
334     if (buffer_.len()) {
335       return;
336     }
337 
338     // Check that the first byte is as requested.
339     if (content_type != kTlsHandshakeType) {
340       return;
341     }
342 
343     TlsParser parser(data, len);
344     unsigned char message_type;
345     // Read the handshake message type.
346     if (!parser.Read(&message_type)) {
347       return;
348     }
349     if (message_type != handshake_type_) {
350       return;
351     }
352 
353     uint32_t length;
354     if (!parser.Read(&length, 3)) {
355       return;
356     }
357 
358     uint32_t message_seq;
359     if (!parser.Read(&message_seq, 2)) {
360       return;
361     }
362 
363     uint32_t fragment_offset;
364     if (!parser.Read(&fragment_offset, 3)) {
365       return;
366     }
367 
368     uint32_t fragment_length;
369     if (!parser.Read(&fragment_length, 3)) {
370       return;
371     }
372 
373     if ((fragment_offset != 0) || (fragment_length != length)) {
374       // This shouldn't happen because all current tests where we
375       // are using this code don't fragment.
376       return;
377     }
378 
379     UniquePtr<uint8_t[]> buffer(new uint8_t[length]);
380     if (!parser.Read(buffer.get(), length)) {
381       return;
382     }
383     buffer_.Take(std::move(buffer), length);
384   }
385 
buffer()386   const MediaPacket& buffer() { return buffer_; }
387 
388  private:
389   uint8_t handshake_type_;
390   MediaPacket buffer_;
391 };
392 
393 class TlsServerKeyExchangeECDHE {
394  public:
Parse(const unsigned char * data,size_t len)395   bool Parse(const unsigned char* data, size_t len) {
396     TlsParser parser(data, len);
397 
398     uint8_t curve_type;
399     if (!parser.Read(&curve_type)) {
400       return false;
401     }
402 
403     if (curve_type != 3) {  // named_curve
404       return false;
405     }
406 
407     uint32_t named_curve;
408     if (!parser.Read(&named_curve, 2)) {
409       return false;
410     }
411 
412     uint32_t point_length;
413     if (!parser.Read(&point_length, 1)) {
414       return false;
415     }
416 
417     UniquePtr<uint8_t[]> key(new uint8_t[point_length]);
418     if (!parser.Read(key.get(), point_length)) {
419       return false;
420     }
421     public_key_.Take(std::move(key), point_length);
422 
423     return true;
424   }
425 
426   MediaPacket public_key_;
427 };
428 
429 namespace {
430 class TransportTestPeer : public sigslot::has_slots<> {
431  public:
TransportTestPeer(nsCOMPtr<nsIEventTarget> target,std::string name,MtransportTestUtils * utils)432   TransportTestPeer(nsCOMPtr<nsIEventTarget> target, std::string name,
433                     MtransportTestUtils* utils)
434       : name_(name),
435         offerer_(name == "P1"),
436         target_(target),
437         received_packets_(0),
438         received_bytes_(0),
439         flow_(new TransportFlow(name)),
440         loopback_(new TransportLayerLoopback()),
441         logging_(new TransportLayerLogging()),
442         lossy_(new TransportLayerLossy()),
443         dtls_(new TransportLayerDtls()),
444         identity_(DtlsIdentity::Generate()),
445         ice_ctx_(),
446         streams_(),
447         peer_(nullptr),
448         gathering_complete_(false),
449         digest_("sha-1"),
450         enabled_cipersuites_(),
451         disabled_cipersuites_(),
452         test_utils_(utils) {
453     NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig());
454     ice_ctx_ = NrIceCtx::Create(name, NrIceCtx::Config());
455     std::vector<NrIceStunServer> stun_servers;
456     UniquePtr<NrIceStunServer> server(NrIceStunServer::Create(
457         std::string((char*)"stun.services.mozilla.com"), 3478));
458     stun_servers.push_back(*server);
459     EXPECT_TRUE(NS_SUCCEEDED(ice_ctx_->SetStunServers(stun_servers)));
460 
461     dtls_->SetIdentity(identity_);
462     dtls_->SetRole(offerer_ ? TransportLayerDtls::SERVER
463                             : TransportLayerDtls::CLIENT);
464 
465     nsresult res = identity_->ComputeFingerprint(&digest_);
466     EXPECT_TRUE(NS_SUCCEEDED(res));
467     EXPECT_EQ(20u, digest_.value_.size());
468   }
469 
~TransportTestPeer()470   ~TransportTestPeer() {
471     test_utils_->sts_target()->Dispatch(
472         WrapRunnable(this, &TransportTestPeer::DestroyFlow), NS_DISPATCH_SYNC);
473   }
474 
DestroyFlow()475   void DestroyFlow() {
476     disconnect_all();
477     if (flow_) {
478       loopback_->Disconnect();
479       flow_ = nullptr;
480     }
481     ice_ctx_->Destroy();
482     ice_ctx_ = nullptr;
483     streams_.clear();
484   }
485 
DisconnectDestroyFlow()486   void DisconnectDestroyFlow() {
487     test_utils_->sts_target()->Dispatch(
488         NS_NewRunnableFunction(
489             __func__,
490             [this] {
491               loopback_->Disconnect();
492               disconnect_all();  // Disconnect from the signals;
493               flow_ = nullptr;
494             }),
495         NS_DISPATCH_SYNC);
496   }
497 
SetDtlsAllowAll()498   void SetDtlsAllowAll() {
499     nsresult res = dtls_->SetVerificationAllowAll();
500     ASSERT_TRUE(NS_SUCCEEDED(res));
501   }
502 
SetAlpn(std::string str,bool withDefault,std::string extra="")503   void SetAlpn(std::string str, bool withDefault, std::string extra = "") {
504     std::set<std::string> alpn;
505     alpn.insert(str);  // the one we want to select
506     if (!extra.empty()) {
507       alpn.insert(extra);
508     }
509     nsresult res = dtls_->SetAlpn(alpn, withDefault ? str : "");
510     ASSERT_EQ(NS_OK, res);
511   }
512 
GetAlpn() const513   const std::string& GetAlpn() const { return dtls_->GetNegotiatedAlpn(); }
514 
SetDtlsPeer(TransportTestPeer * peer,int digests,unsigned int damage)515   void SetDtlsPeer(TransportTestPeer* peer, int digests, unsigned int damage) {
516     unsigned int mask = 1;
517 
518     for (int i = 0; i < digests; i++) {
519       DtlsDigest digest_to_set(peer->digest_);
520 
521       if (damage & mask) digest_to_set.value_.data()[0]++;
522 
523       nsresult res = dtls_->SetVerificationDigest(digest_to_set);
524 
525       ASSERT_TRUE(NS_SUCCEEDED(res));
526 
527       mask <<= 1;
528     }
529   }
530 
SetupSrtp()531   void SetupSrtp() {
532     std::vector<uint16_t> srtp_ciphers =
533         TransportLayerDtls::GetDefaultSrtpCiphers();
534     SetSrtpCiphers(srtp_ciphers);
535   }
536 
SetSrtpCiphers(std::vector<uint16_t> & srtp_ciphers)537   void SetSrtpCiphers(std::vector<uint16_t>& srtp_ciphers) {
538     ASSERT_TRUE(NS_SUCCEEDED(dtls_->SetSrtpCiphers(srtp_ciphers)));
539   }
540 
ConnectSocket_s(TransportTestPeer * peer)541   void ConnectSocket_s(TransportTestPeer* peer) {
542     nsresult res;
543     res = loopback_->Init();
544     ASSERT_EQ((nsresult)NS_OK, res);
545 
546     loopback_->Connect(peer->loopback_);
547     ASSERT_EQ((nsresult)NS_OK, loopback_->Init());
548     ASSERT_EQ((nsresult)NS_OK, logging_->Init());
549     ASSERT_EQ((nsresult)NS_OK, lossy_->Init());
550     ASSERT_EQ((nsresult)NS_OK, dtls_->Init());
551     dtls_->Chain(lossy_);
552     lossy_->Chain(logging_);
553     logging_->Chain(loopback_);
554 
555     flow_->PushLayer(loopback_);
556     flow_->PushLayer(logging_);
557     flow_->PushLayer(lossy_);
558     flow_->PushLayer(dtls_);
559 
560     if (dtls_->state() != TransportLayer::TS_ERROR) {
561       // Don't execute these blocks if DTLS didn't initialize.
562       TweakCiphers(dtls_->internal_fd());
563       if (post_setup_) {
564         post_setup_(dtls_->internal_fd());
565       }
566     }
567 
568     dtls_->SignalPacketReceived.connect(this,
569                                         &TransportTestPeer::PacketReceived);
570   }
571 
TweakCiphers(PRFileDesc * fd)572   void TweakCiphers(PRFileDesc* fd) {
573     for (unsigned short& enabled_cipersuite : enabled_cipersuites_) {
574       SSL_CipherPrefSet(fd, enabled_cipersuite, PR_TRUE);
575     }
576     for (unsigned short& disabled_cipersuite : disabled_cipersuites_) {
577       SSL_CipherPrefSet(fd, disabled_cipersuite, PR_FALSE);
578     }
579   }
580 
ConnectSocket(TransportTestPeer * peer)581   void ConnectSocket(TransportTestPeer* peer) {
582     RUN_ON_THREAD(test_utils_->sts_target(),
583                   WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer),
584                   NS_DISPATCH_SYNC);
585   }
586 
InitIce_s()587   nsresult InitIce_s() {
588     nsresult rv = ice_->Init();
589     NS_ENSURE_SUCCESS(rv, rv);
590     rv = dtls_->Init();
591     NS_ENSURE_SUCCESS(rv, rv);
592     dtls_->Chain(ice_);
593     flow_->PushLayer(ice_);
594     flow_->PushLayer(dtls_);
595     return NS_OK;
596   }
597 
InitIce()598   void InitIce() {
599     nsresult res;
600 
601     // Attach our slots
602     ice_ctx_->SignalGatheringStateChange.connect(
603         this, &TransportTestPeer::GatheringStateChange);
604 
605     char name[100];
606     snprintf(name, sizeof(name), "%s:stream%d", name_.c_str(),
607              (int)streams_.size());
608 
609     // Create the media stream
610     RefPtr<NrIceMediaStream> stream = ice_ctx_->CreateStream(name, name, 1);
611 
612     ASSERT_TRUE(stream != nullptr);
613     stream->SetIceCredentials("ufrag", "pass");
614     streams_.push_back(stream);
615 
616     // Listen for candidates
617     stream->SignalCandidate.connect(this, &TransportTestPeer::GotCandidate);
618 
619     // Create the transport layer
620     ice_ = new TransportLayerIce();
621     ice_->SetParameters(stream, 1);
622 
623     test_utils_->sts_target()->Dispatch(
624         WrapRunnableRet(&res, this, &TransportTestPeer::InitIce_s),
625         NS_DISPATCH_SYNC);
626 
627     ASSERT_EQ((nsresult)NS_OK, res);
628 
629     // Listen for media events
630     dtls_->SignalPacketReceived.connect(this,
631                                         &TransportTestPeer::PacketReceived);
632     dtls_->SignalStateChange.connect(this, &TransportTestPeer::StateChanged);
633 
634     // Start gathering
635     test_utils_->sts_target()->Dispatch(
636         WrapRunnableRet(&res, ice_ctx_, &NrIceCtx::StartGathering, false,
637                         false),
638         NS_DISPATCH_SYNC);
639     ASSERT_TRUE(NS_SUCCEEDED(res));
640   }
641 
ConnectIce(TransportTestPeer * peer)642   void ConnectIce(TransportTestPeer* peer) {
643     peer_ = peer;
644 
645     // If gathering is already complete, push the candidates over
646     if (gathering_complete_) GatheringComplete();
647   }
648 
649   // New candidate
GotCandidate(NrIceMediaStream * stream,const std::string & candidate,const std::string & ufrag,const std::string & mdns_addr,const std::string & actual_addr)650   void GotCandidate(NrIceMediaStream* stream, const std::string& candidate,
651                     const std::string& ufrag, const std::string& mdns_addr,
652                     const std::string& actual_addr) {
653     std::cerr << "Got candidate " << candidate << " (ufrag=" << ufrag << ")"
654               << std::endl;
655   }
656 
GatheringStateChange(NrIceCtx * ctx,NrIceCtx::GatheringState state)657   void GatheringStateChange(NrIceCtx* ctx, NrIceCtx::GatheringState state) {
658     (void)ctx;
659     if (state == NrIceCtx::ICE_CTX_GATHER_COMPLETE) {
660       GatheringComplete();
661     }
662   }
663 
664   // Gathering complete, so send our candidates and start
665   // connecting on the other peer.
GatheringComplete()666   void GatheringComplete() {
667     nsresult res;
668 
669     // Don't send to the other side
670     if (!peer_) {
671       gathering_complete_ = true;
672       return;
673     }
674 
675     // First send attributes
676     test_utils_->sts_target()->Dispatch(
677         WrapRunnableRet(&res, peer_->ice_ctx_, &NrIceCtx::ParseGlobalAttributes,
678                         ice_ctx_->GetGlobalAttributes()),
679         NS_DISPATCH_SYNC);
680     ASSERT_TRUE(NS_SUCCEEDED(res));
681 
682     for (size_t i = 0; i < streams_.size(); ++i) {
683       test_utils_->sts_target()->Dispatch(
684           WrapRunnableRet(&res, peer_->streams_[i],
685                           &NrIceMediaStream::ConnectToPeer, "ufrag", "pass",
686                           streams_[i]->GetAttributes()),
687           NS_DISPATCH_SYNC);
688 
689       ASSERT_TRUE(NS_SUCCEEDED(res));
690     }
691 
692     // Start checks on the other peer.
693     test_utils_->sts_target()->Dispatch(
694         WrapRunnableRet(&res, peer_->ice_ctx_, &NrIceCtx::StartChecks),
695         NS_DISPATCH_SYNC);
696     ASSERT_TRUE(NS_SUCCEEDED(res));
697   }
698 
699   // WrapRunnable/lambda and move semantics (MediaPacket is not copyable) don't
700   // get along yet, so we need a wrapper. Gross.
SendPacketWrapper(TransportLayer * layer,MediaPacket * packet)701   static TransportResult SendPacketWrapper(TransportLayer* layer,
702                                            MediaPacket* packet) {
703     return layer->SendPacket(*packet);
704   }
705 
SendPacket(MediaPacket & packet)706   TransportResult SendPacket(MediaPacket& packet) {
707     TransportResult ret;
708 
709     test_utils_->sts_target()->Dispatch(
710         WrapRunnableNMRet(&ret, &TransportTestPeer::SendPacketWrapper, dtls_,
711                           &packet),
712         NS_DISPATCH_SYNC);
713 
714     return ret;
715   }
716 
StateChanged(TransportLayer * layer,TransportLayer::State state)717   void StateChanged(TransportLayer* layer, TransportLayer::State state) {
718     if (state == TransportLayer::TS_OPEN) {
719       std::cerr << "Now connected" << std::endl;
720     }
721   }
722 
PacketReceived(TransportLayer * layer,MediaPacket & packet)723   void PacketReceived(TransportLayer* layer, MediaPacket& packet) {
724     std::cerr << "Received " << packet.len() << " bytes" << std::endl;
725     ++received_packets_;
726     received_bytes_ += packet.len();
727   }
728 
SetLoss(uint32_t loss)729   void SetLoss(uint32_t loss) { lossy_->SetLoss(loss); }
730 
SetCombinePackets(bool combine)731   void SetCombinePackets(bool combine) { loopback_->CombinePackets(combine); }
732 
SetInspector(UniquePtr<Inspector> inspector)733   void SetInspector(UniquePtr<Inspector> inspector) {
734     lossy_->SetInspector(std::move(inspector));
735   }
736 
SetInspector(Inspector * in)737   void SetInspector(Inspector* in) {
738     UniquePtr<Inspector> inspector(in);
739 
740     lossy_->SetInspector(std::move(inspector));
741   }
742 
SetCipherSuiteChanges(const std::vector<uint16_t> & enableThese,const std::vector<uint16_t> & disableThese)743   void SetCipherSuiteChanges(const std::vector<uint16_t>& enableThese,
744                              const std::vector<uint16_t>& disableThese) {
745     disabled_cipersuites_ = disableThese;
746     enabled_cipersuites_ = enableThese;
747   }
748 
SetPostSetup(const std::function<void (PRFileDesc *)> & setup)749   void SetPostSetup(const std::function<void(PRFileDesc*)>& setup) {
750     post_setup_ = std::move(setup);
751   }
752 
state()753   TransportLayer::State state() {
754     TransportLayer::State tstate;
755 
756     RUN_ON_THREAD(test_utils_->sts_target(),
757                   WrapRunnableRet(&tstate, dtls_, &TransportLayer::state));
758 
759     return tstate;
760   }
761 
connected()762   bool connected() { return state() == TransportLayer::TS_OPEN; }
763 
failed()764   bool failed() { return state() == TransportLayer::TS_ERROR; }
765 
receivedPackets()766   size_t receivedPackets() { return received_packets_; }
767 
receivedBytes()768   size_t receivedBytes() { return received_bytes_; }
769 
cipherSuite() const770   uint16_t cipherSuite() const {
771     nsresult rv;
772     uint16_t cipher;
773     RUN_ON_THREAD(
774         test_utils_->sts_target(),
775         WrapRunnableRet(&rv, dtls_, &TransportLayerDtls::GetCipherSuite,
776                         &cipher));
777 
778     if (NS_FAILED(rv)) {
779       return TLS_NULL_WITH_NULL_NULL;  // i.e., not good
780     }
781     return cipher;
782   }
783 
srtpCipher() const784   uint16_t srtpCipher() const {
785     nsresult rv;
786     uint16_t cipher;
787     RUN_ON_THREAD(test_utils_->sts_target(),
788                   WrapRunnableRet(&rv, dtls_,
789                                   &TransportLayerDtls::GetSrtpCipher, &cipher));
790     if (NS_FAILED(rv)) {
791       return 0;  // the SRTP equivalent of TLS_NULL_WITH_NULL_NULL
792     }
793     return cipher;
794   }
795 
796  private:
797   std::string name_;
798   bool offerer_;
799   nsCOMPtr<nsIEventTarget> target_;
800   size_t received_packets_;
801   size_t received_bytes_;
802   RefPtr<TransportFlow> flow_;
803   TransportLayerLoopback* loopback_;
804   TransportLayerLogging* logging_;
805   TransportLayerLossy* lossy_;
806   TransportLayerDtls* dtls_;
807   TransportLayerIce* ice_;
808   RefPtr<DtlsIdentity> identity_;
809   RefPtr<NrIceCtx> ice_ctx_;
810   std::vector<RefPtr<NrIceMediaStream> > streams_;
811   TransportTestPeer* peer_;
812   bool gathering_complete_;
813   DtlsDigest digest_;
814   std::vector<uint16_t> enabled_cipersuites_;
815   std::vector<uint16_t> disabled_cipersuites_;
816   MtransportTestUtils* test_utils_;
817   std::function<void(PRFileDesc* fd)> post_setup_ = nullptr;
818 };
819 
820 class TransportTest : public MtransportTest {
821  public:
TransportTest()822   TransportTest() {
823     fds_[0] = nullptr;
824     fds_[1] = nullptr;
825     p1_ = nullptr;
826     p2_ = nullptr;
827   }
828 
TearDown()829   void TearDown() override {
830     delete p1_;
831     delete p2_;
832 
833     //    Can't detach these
834     //    PR_Close(fds_[0]);
835     //    PR_Close(fds_[1]);
836     MtransportTest::TearDown();
837   }
838 
DestroyPeerFlows()839   void DestroyPeerFlows() {
840     p1_->DisconnectDestroyFlow();
841     p2_->DisconnectDestroyFlow();
842   }
843 
SetUp()844   void SetUp() override {
845     MtransportTest::SetUp();
846 
847     nsresult rv;
848     target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
849     ASSERT_TRUE(NS_SUCCEEDED(rv));
850 
851     Reset();
852   }
853 
Reset()854   void Reset() {
855     if (p1_) {
856       delete p1_;
857     }
858     if (p2_) {
859       delete p2_;
860     }
861     p1_ = new TransportTestPeer(target_, "P1", test_utils_);
862     p2_ = new TransportTestPeer(target_, "P2", test_utils_);
863   }
864 
SetupSrtp()865   void SetupSrtp() {
866     p1_->SetupSrtp();
867     p2_->SetupSrtp();
868   }
869 
SetDtlsPeer(int digests=1,unsigned int damage=0)870   void SetDtlsPeer(int digests = 1, unsigned int damage = 0) {
871     p1_->SetDtlsPeer(p2_, digests, damage);
872     p2_->SetDtlsPeer(p1_, digests, damage);
873   }
874 
SetDtlsAllowAll()875   void SetDtlsAllowAll() {
876     p1_->SetDtlsAllowAll();
877     p2_->SetDtlsAllowAll();
878   }
879 
SetAlpn(std::string first,std::string second,bool withDefaults=true)880   void SetAlpn(std::string first, std::string second,
881                bool withDefaults = true) {
882     if (!first.empty()) {
883       p1_->SetAlpn(first, withDefaults, "bogus");
884     }
885     if (!second.empty()) {
886       p2_->SetAlpn(second, withDefaults);
887     }
888   }
889 
CheckAlpn(std::string first,std::string second)890   void CheckAlpn(std::string first, std::string second) {
891     ASSERT_EQ(first, p1_->GetAlpn());
892     ASSERT_EQ(second, p2_->GetAlpn());
893   }
894 
ConnectSocket()895   void ConnectSocket() {
896     ConnectSocketInternal();
897     ASSERT_TRUE_WAIT(p1_->connected(), 10000);
898     ASSERT_TRUE_WAIT(p2_->connected(), 10000);
899 
900     ASSERT_EQ(p1_->cipherSuite(), p2_->cipherSuite());
901     ASSERT_EQ(p1_->srtpCipher(), p2_->srtpCipher());
902   }
903 
ConnectSocketExpectFail()904   void ConnectSocketExpectFail() {
905     ConnectSocketInternal();
906     ASSERT_TRUE_WAIT(p1_->failed(), 10000);
907     ASSERT_TRUE_WAIT(p2_->failed(), 10000);
908   }
909 
ConnectSocketExpectState(TransportLayer::State s1,TransportLayer::State s2)910   void ConnectSocketExpectState(TransportLayer::State s1,
911                                 TransportLayer::State s2) {
912     ConnectSocketInternal();
913     ASSERT_EQ_WAIT(s1, p1_->state(), 10000);
914     ASSERT_EQ_WAIT(s2, p2_->state(), 10000);
915   }
916 
ConnectIce()917   void ConnectIce() {
918     p1_->InitIce();
919     p2_->InitIce();
920     p1_->ConnectIce(p2_);
921     p2_->ConnectIce(p1_);
922     ASSERT_TRUE_WAIT(p1_->connected(), 10000);
923     ASSERT_TRUE_WAIT(p2_->connected(), 10000);
924   }
925 
TransferTest(size_t count,size_t bytes=1024)926   void TransferTest(size_t count, size_t bytes = 1024) {
927     unsigned char buf[bytes];
928 
929     for (size_t i = 0; i < count; ++i) {
930       memset(buf, count & 0xff, sizeof(buf));
931       MediaPacket packet;
932       packet.Copy(buf, sizeof(buf));
933       TransportResult rv = p1_->SendPacket(packet);
934       ASSERT_TRUE(rv > 0);
935     }
936 
937     std::cerr << "Received == " << p2_->receivedPackets() << " packets"
938               << std::endl;
939     ASSERT_TRUE_WAIT(count == p2_->receivedPackets(), 10000);
940     ASSERT_TRUE((count * sizeof(buf)) == p2_->receivedBytes());
941   }
942 
943  protected:
  // Kicks off the connection by running p1_->ConnectSocket(p2_) and then
  // p2_->ConnectSocket(p1_) on the STS thread. Both dispatches use
  // NS_DISPATCH_SYNC, so this returns only after each call has run there;
  // callers then wait for the resulting peer state themselves.
  void ConnectSocketInternal() {
    test_utils_->sts_target()->Dispatch(
        WrapRunnable(p1_, &TransportTestPeer::ConnectSocket, p2_),
        NS_DISPATCH_SYNC);
    test_utils_->sts_target()->Dispatch(
        WrapRunnable(p2_, &TransportTestPeer::ConnectSocket, p1_),
        NS_DISPATCH_SYNC);
  }
952 
953   PRFileDesc* fds_[2];
954   TransportTestPeer* p1_;
955   TransportTestPeer* p2_;
956   nsCOMPtr<nsIEventTarget> target_;
957 };
958 
// Neither SetDtlsPeer() nor SetDtlsAllowAll() is called, so the peers have no
// DTLS verification settings and the handshake must fail on both sides.
TEST_F(TransportTest, TestNoDtlsVerificationSettings) {
  ConnectSocketExpectFail();
}
962 
DisableChaCha(TransportTestPeer * peer)963 static void DisableChaCha(TransportTestPeer* peer) {
964   // On ARM, ChaCha20Poly1305 might be preferred; disable it for the tests that
965   // want to check the cipher suite.  It doesn't matter which peer disables the
966   // suite, disabling on either side has the same effect.
967   std::vector<uint16_t> chachaSuites;
968   chachaSuites.push_back(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256);
969   chachaSuites.push_back(TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256);
970   peer->SetCipherSuiteChanges(std::vector<uint16_t>(), chachaSuites);
971 }
972 
// Basic DTLS connection without SRTP. ChaCha20-Poly1305 is disabled on p1_ so
// the AES-128-GCM suite is negotiated even where ChaCha might be preferred.
TEST_F(TransportTest, TestConnect) {
  SetDtlsPeer();
  DisableChaCha(p1_);
  ConnectSocket();

  // check that we got the right suite
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());

  // no SRTP on this one
  ASSERT_EQ(0, p1_->srtpCipher());
}
984 
// DTLS connection with SRTP configured through the fixture's SetupSrtp()
// (presumably on both peers). Checks the negotiated TLS suite and that the
// default SRTP selection is AEAD AES-128-GCM.
TEST_F(TransportTest, TestConnectSrtp) {
  SetupSrtp();
  SetDtlsPeer();
  DisableChaCha(p2_);
  ConnectSocket();

  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());

  // SRTP is on with default value
  ASSERT_EQ(kDtlsSrtpAeadAes128Gcm, p1_->srtpCipher());
}
996 
// Connects, then tears down both peers' flows via DestroyPeerFlows() (from the
// main thread, per the test name). There are no checks here beyond whatever
// DestroyPeerFlows() itself asserts.
TEST_F(TransportTest, TestConnectDestroyFlowsMainThread) {
  SetDtlsPeer();
  ConnectSocket();
  DestroyPeerFlows();
}
1002 
// With SetDtlsAllowAll() instead of SetDtlsPeer() (presumably: accept any
// peer certificate without digest checks), the connection must succeed.
TEST_F(TransportTest, TestConnectAllowAll) {
  SetDtlsAllowAll();
  ConnectSocket();
}
1007 
// Server and client both configure ALPN "a"; the connection succeeds and
// CheckAlpn() verifies the ALPN outcome on each side.
TEST_F(TransportTest, TestConnectAlpn) {
  SetDtlsPeer();
  SetAlpn("a", "a");
  ConnectSocket();
  CheckAlpn("a", "a");
}
1014 
// Server and client configure disjoint ALPN values; the connection must fail.
TEST_F(TransportTest, TestConnectAlpnMismatch) {
  SetDtlsPeer();
  SetAlpn("something", "different");
  ConnectSocketExpectFail();
}
1020 
// Only the server configures ALPN ("def"); the connection still succeeds.
TEST_F(TransportTest, TestConnectAlpnServerDefault) {
  SetDtlsPeer();
  SetAlpn("def", "");
  // server allows default, client doesn't support
  ConnectSocket();
  CheckAlpn("def", "");
}
1028 
// Only the client configures ALPN ("clientdef"); the connection still
// succeeds.
TEST_F(TransportTest, TestConnectAlpnClientDefault) {
  SetDtlsPeer();
  SetAlpn("", "clientdef");
  // client allows default, but server will ignore the extension
  ConnectSocket();
  CheckAlpn("", "clientdef");
}
1036 
// Server requires ALPN (no default allowed) but the client offers none: the
// server ends in TS_ERROR and the client sees TS_CLOSED.
TEST_F(TransportTest, TestConnectClientNoAlpn) {
  SetDtlsPeer();
  // Here the server has ALPN, but no default is allowed.
  // Reminder: p1 == server, p2 == client
  SetAlpn("server-nodefault", "", false);
  // The server doesn't see the extension, so negotiates without it.
  // But then the server is forced to close when it discovers that ALPN wasn't
  // negotiated; the client sees a close.
  ConnectSocketExpectState(TransportLayer::TS_ERROR, TransportLayer::TS_CLOSED);
}
1047 
// Client requires ALPN (no default allowed) but the server offers none: the
// client (p2_) ends in TS_ERROR and the server (p1_) sees TS_CLOSED.
TEST_F(TransportTest, TestConnectServerNoAlpn) {
  SetDtlsPeer();
  SetAlpn("", "client-nodefault", false);
  // The client aborts; the server doesn't realize this is a problem and just
  // sees the close.
  ConnectSocketExpectState(TransportLayer::TS_CLOSED, TransportLayer::TS_ERROR);
}
1055 
// SetDtlsPeer(0, 0): judging by the test names, no fingerprint digests are
// configured, so verification cannot succeed and the connection fails.
TEST_F(TransportTest, TestConnectNoDigest) {
  SetDtlsPeer(0, 0);

  ConnectSocketExpectFail();
}
1061 
// SetDtlsPeer(1, 1): a single digest, which is corrupted (the second argument
// appears to be a mask of damaged digests). The connection must fail.
TEST_F(TransportTest, TestConnectBadDigest) {
  SetDtlsPeer(1, 1);

  ConnectSocketExpectFail();
}
1067 
// SetDtlsPeer(2, 0): two digests, neither corrupted; must connect.
TEST_F(TransportTest, TestConnectTwoDigests) {
  SetDtlsPeer(2, 0);

  ConnectSocket();
}
1073 
// SetDtlsPeer(2, 1): two digests, the first corrupted. One good digest is
// enough, so the connection must still succeed.
TEST_F(TransportTest, TestConnectTwoDigestsFirstBad) {
  SetDtlsPeer(2, 1);

  ConnectSocket();
}
1079 
// SetDtlsPeer(2, 2): two digests, the second corrupted. One good digest is
// enough, so the connection must still succeed.
TEST_F(TransportTest, TestConnectTwoDigestsSecondBad) {
  SetDtlsPeer(2, 2);

  ConnectSocket();
}
1085 
// SetDtlsPeer(2, 3): two digests, both corrupted; the connection must fail.
TEST_F(TransportTest, TestConnectTwoDigestsBothBad) {
  SetDtlsPeer(2, 3);

  ConnectSocketExpectFail();
}
1091 
// Installs an inspector on p2_ that injects kTlsFakeChangeCipherSpec, keyed on
// a handshake record carrying a Certificate message. The handshake must
// still complete despite the stray ChangeCipherSpec.
TEST_F(TransportTest, TestConnectInjectCCS) {
  SetDtlsPeer();
  p2_->SetInspector(MakeUnique<DtlsInspectorInjector>(
      kTlsHandshakeType, kTlsHandshakeCertificate, kTlsFakeChangeCipherSpec,
      sizeof(kTlsFakeChangeCipherSpec)));

  ConnectSocket();
}
1100 
// Runs two independent connections, records the ServerKeyExchange message via
// an inspector on p1_ (the server) each time, and asserts that the two ECDHE
// public keys differ, i.e. the server does not reuse its ECDHE key by default.
// NOTE(review): i1/i2 are handed to SetInspector() and still dereferenced
// afterwards; presumably the peer owns them until Reset() -- confirm.
TEST_F(TransportTest, TestConnectVerifyNewECDHE) {
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* i1 =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(i1);
  ConnectSocket();
  TlsServerKeyExchangeECDHE dhe1;
  ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));

  Reset();
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* i2 =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(i2);
  ConnectSocket();
  TlsServerKeyExchangeECDHE dhe2;
  ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));

  // Now compare these two to see if they are the same; the test fails if
  // both keys have the same length and identical bytes.
  ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
               (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
                        dhe1.public_key_.len())));
}
1124 
// Counterpart of TestConnectVerifyNewECDHE: with SSL_REUSE_SERVER_ECDHE_KEY
// forced back on for the server (p1_) after setup, two independent
// connections must see the same ECDHE public key in ServerKeyExchange.
TEST_F(TransportTest, TestConnectVerifyReusedECDHE) {
  auto set_reuse_ecdhe_key = [](PRFileDesc* fd) {
    // TransportLayerDtls automatically sets this option to false
    // so set it back for test.
    // This is pretty gross. Dig directly into the NSS FD. The problem
    // is that we are testing a feature which TransportLayerDtls doesn't
    // expose.
    SECStatus rv = SSL_OptionSet(fd, SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
    ASSERT_EQ(SECSuccess, rv);
  };

  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* i1 =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
  p1_->SetInspector(i1);
  p1_->SetPostSetup(set_reuse_ecdhe_key);
  ConnectSocket();
  TlsServerKeyExchangeECDHE dhe1;
  ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));

  Reset();
  SetDtlsPeer();
  DtlsInspectorRecordHandshakeMessage* i2 =
      new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);

  p1_->SetInspector(i2);
  p1_->SetPostSetup(set_reuse_ecdhe_key);

  ConnectSocket();
  TlsServerKeyExchangeECDHE dhe2;
  ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));

  // Now compare these two; they must be byte-for-byte identical.
  ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
  ASSERT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
                      dhe1.public_key_.len()));
}
1162 
// Connects over the socket path and transfers one default-sized packet.
TEST_F(TransportTest, TestTransfer) {
  SetDtlsPeer();
  ConnectSocket();
  TransferTest(1);
}
1168 
// Transfers one packet of the largest size the DTLS layer's buffer allows.
TEST_F(TransportTest, TestTransferMaxSize) {
  SetDtlsPeer();
  ConnectSocket();
  /* transportlayerdtls uses a 9216 bytes buffer - as this test uses the
   * loopback implementation it does not have to take into account the extra
   * bytes added by the DTLS layer below. */
  TransferTest(1, 9216);
}
1177 
// Connects over the socket path and transfers three default-sized packets.
TEST_F(TransportTest, TestTransferMultiple) {
  SetDtlsPeer();
  ConnectSocket();
  TransferTest(3);
}
1183 
// Same as TestTransferMultiple, but with SetCombinePackets(true) on the
// receiver (p2_); the packet and byte counts must still match.
TEST_F(TransportTest, TestTransferCombinedPackets) {
  SetDtlsPeer();
  ConnectSocket();
  p2_->SetCombinePackets(true);
  TransferTest(3);
}
1190 
// p1_->SetLoss(0) (presumably: drop the first packet p1_ sends) during the
// handshake; DTLS retransmission must still let the connection and a
// subsequent transfer succeed.
TEST_F(TransportTest, TestConnectLoseFirst) {
  SetDtlsPeer();
  p1_->SetLoss(0);
  ConnectSocket();
  TransferTest(1);
}
1197 
// Connects the peers over ICE instead of the socket path.
TEST_F(TransportTest, TestConnectIce) {
  SetDtlsPeer();
  ConnectIce();
}
1202 
// Transfers one large packet over ICE, leaving headroom for DTLS overhead.
TEST_F(TransportTest, TestTransferIceMaxSize) {
  SetDtlsPeer();
  ConnectIce();
  /* nICEr and transportlayerdtls both use 9216 bytes buffers. But the DTLS
   * layer adds extra bytes to the packet, whose size depends on the chosen
   * cipher etc. Sending more than 9216 bytes works, but on the receiving side
   * the call to PR_recvfrom() will truncate any packet bigger than nICEr's
   * buffer size of 9216 bytes, which then results in the DTLS layer
   * discarding the packet. Therefore we leave some headroom (according to
   * https://bugzilla.mozilla.org/show_bug.cgi?id=1214269#c29 256 bytes should
   * be a safe choice) here for the DTLS bytes to make it safely into the
   * receiving buffer in nICEr. */
  TransferTest(1, 8960);
}
1217 
// Connects over ICE and transfers three default-sized packets.
TEST_F(TransportTest, TestTransferIceMultiple) {
  SetDtlsPeer();
  ConnectIce();
  TransferTest(3);
}
1223 
// ICE variant of TestTransferCombinedPackets: the receiver (p2_) has
// SetCombinePackets(true); the packet and byte counts must still match.
TEST_F(TransportTest, TestTransferIceCombinedPackets) {
  SetDtlsPeer();
  ConnectIce();
  p2_->SetCombinePackets(true);
  TransferTest(3);
}
1230 
1231 // test the default configuration against a peer that supports only
1232 // one of the mandatory-to-implement suites, which should succeed
ConfigureOneCipher(TransportTestPeer * peer,uint16_t suite)1233 static void ConfigureOneCipher(TransportTestPeer* peer, uint16_t suite) {
1234   std::vector<uint16_t> justOne;
1235   justOne.push_back(suite);
1236   std::vector<uint16_t> everythingElse(
1237       SSL_GetImplementedCiphers(),
1238       SSL_GetImplementedCiphers() + SSL_GetNumImplementedCiphers());
1239   std::remove(everythingElse.begin(), everythingElse.end(), suite);
1240   peer->SetCipherSuiteChanges(justOne, everythingElse);
1241 }
1242 
// Server (p1_) allows only AES-128-GCM and client (p2_) only AES-128-CBC, so
// there is no common suite and the handshake must fail.
TEST_F(TransportTest, TestCipherMismatch) {
  SetDtlsPeer();
  ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256);
  ConfigureOneCipher(p2_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA);
  ConnectSocketExpectFail();
}
1249 
// A server restricted to AES-128-GCM still connects with a default client,
// and that suite is the one negotiated.
TEST_F(TransportTest, TestCipherMandatoryOnlyGcm) {
  SetDtlsPeer();
  ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256);
  ConnectSocket();
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
}
1256 
// A server restricted to AES-128-CBC-SHA still connects with a default
// client, and that suite is the one negotiated.
TEST_F(TransportTest, TestCipherMandatoryOnlyCbc) {
  SetDtlsPeer();
  ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA);
  ConnectSocket();
  ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, p1_->cipherSuite());
}
1263 
// p1_ offers only the SHA1_80 SRTP profile and p2_ only SHA1_32. The
// connection must fail, and neither side may report a negotiated SRTP cipher.
TEST_F(TransportTest, TestSrtpMismatch) {
  std::vector<uint16_t> setA;
  setA.push_back(kDtlsSrtpAes128CmHmacSha1_80);
  std::vector<uint16_t> setB;
  setB.push_back(kDtlsSrtpAes128CmHmacSha1_32);

  p1_->SetSrtpCiphers(setA);
  p2_->SetSrtpCiphers(setB);
  SetDtlsPeer();
  ConnectSocketExpectFail();

  ASSERT_EQ(0, p1_->srtpCipher());
  ASSERT_EQ(0, p2_->srtpCipher());
}
1278 
NoopXtnHandler(PRFileDesc * fd,SSLHandshakeType message,const uint8_t * data,unsigned int len,SSLAlertDescription * alert,void * arg)1279 static SECStatus NoopXtnHandler(PRFileDesc* fd, SSLHandshakeType message,
1280                                 const uint8_t* data, unsigned int len,
1281                                 SSLAlertDescription* alert, void* arg) {
1282   return SECSuccess;
1283 }
1284 
WriteFixedXtn(PRFileDesc * fd,SSLHandshakeType message,uint8_t * data,unsigned int * len,unsigned int max_len,void * arg)1285 static PRBool WriteFixedXtn(PRFileDesc* fd, SSLHandshakeType message,
1286                             uint8_t* data, unsigned int* len,
1287                             unsigned int max_len, void* arg) {
1288   // When we enable TLS 1.3, change ssl_hs_server_hello here to
1289   // ssl_hs_encrypted_extensions.  At the same time, add a test that writes to
1290   // ssl_hs_server_hello, which should fail.
1291   if (message != ssl_hs_client_hello && message != ssl_hs_server_hello) {
1292     return false;
1293   }
1294 
1295   auto v = reinterpret_cast<std::vector<uint8_t>*>(arg);
1296   memcpy(data, &((*v)[0]), v->size());
1297   *len = v->size();
1298   return true;
1299 }
1300 
1301 // Note that |value| needs to be readable after this function returns.
InstallBadSrtpExtensionWriter(TransportTestPeer * peer,std::vector<uint8_t> * value)1302 static void InstallBadSrtpExtensionWriter(TransportTestPeer* peer,
1303                                           std::vector<uint8_t>* value) {
1304   peer->SetPostSetup([value](PRFileDesc* fd) {
1305     // Override the handler that is installed by the DTLS setup.
1306     SECStatus rv = SSL_InstallExtensionHooks(
1307         fd, ssl_use_srtp_xtn, WriteFixedXtn, value, NoopXtnHandler, nullptr);
1308     ASSERT_EQ(SECSuccess, rv);
1309   });
1310 }
1311 
// use_srtp extension body (RFC 5764, section 4.1.1): a 2-byte length, the
// list of 2-byte protection profiles, a 1-byte MKI length, and the MKI.
// A server must select exactly one profile, so an answer listing two profiles
// has to be rejected by the client.
TEST_F(TransportTest, TestSrtpErrorServerSendsTwoSrtpCiphers) {
  // Server (p1_) sends an extension with two values, and empty MKI.
  std::vector<uint8_t> xtn = {0x00, 0x04,  // profile list length: 4 bytes
                              0x00, 0x01,  // profile 0x0001
                              0x00, 0x02,  // profile 0x0002
                              0x00};       // MKI length: 0
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1320 
// use_srtp extension body (RFC 5764, section 4.1.1): a 2-byte length, the
// list of 2-byte protection profiles, a 1-byte MKI length, and the MKI.
// The server answers with a valid single profile but a non-empty MKI, which
// the client (having offered no MKI) must reject.
TEST_F(TransportTest, TestSrtpErrorServerSendsTwoMki) {
  // Server (p1_) sends an MKI.
  std::vector<uint8_t> xtn = {0x00, 0x02,  // profile list length: 2 bytes
                              0x00, 0x01,  // profile 0x0001
                              0x01,        // MKI length: 1
                              0x00};       // MKI value
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1329 
// use_srtp extension body (RFC 5764, section 4.1.1): a 2-byte length, the
// list of 2-byte protection profiles, a 1-byte MKI length, and the MKI.
// The server selects an unknown profile (0x9af1) that the client never
// offered, which the client must reject.
TEST_F(TransportTest, TestSrtpErrorServerSendsUnknownValue) {
  std::vector<uint8_t> xtn = {0x00, 0x02,  // profile list length: 2 bytes
                              0x9a, 0xf1,  // unknown profile 0x9af1
                              0x00};       // MKI length: 0
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1337 
// use_srtp extension body (RFC 5764, section 4.1.1): a 2-byte length, the
// list of 2-byte protection profiles, a 1-byte MKI length, and the MKI.
// The profile list claims 50 bytes but only 3 bytes follow, so parsing the
// list overruns the extension and the client must reject it.
TEST_F(TransportTest, TestSrtpErrorServerSendsOverflow) {
  std::vector<uint8_t> xtn = {0x00, 0x32,  // profile list length: 50 bytes
                              0x00, 0x01,  // profile 0x0001
                              0x00};       // MKI length: 0
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1345 
// use_srtp extension body (RFC 5764, section 4.1.1): a 2-byte length, the
// list of 2-byte protection profiles, a 1-byte MKI length, and the MKI.
// The server's profile list is 1 byte long, which cannot hold a 2-byte
// profile, so the client must reject it.
TEST_F(TransportTest, TestSrtpErrorServerSendsUnevenList) {
  std::vector<uint8_t> xtn = {0x00, 0x01,  // profile list length: 1 byte
                              0x00,        // truncated profile
                              0x00};       // MKI length: 0
  InstallBadSrtpExtensionWriter(p1_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1353 
// use_srtp extension body (RFC 5764, section 4.1.1): a 2-byte length, the
// list of 2-byte protection profiles, a 1-byte MKI length, and the MKI.
// The client (p2_) offers a profile list of odd length (1 byte), which the
// server must reject.
TEST_F(TransportTest, TestSrtpErrorClientSendsUnevenList) {
  std::vector<uint8_t> xtn = {0x00, 0x01,  // profile list length: 1 byte
                              0x00,        // truncated profile
                              0x00};       // MKI length: 0
  InstallBadSrtpExtensionWriter(p2_, &xtn);
  SetupSrtp();
  SetDtlsPeer();
  ConnectSocketExpectFail();
}
1361 
// Only the server (p1_) is configured for SRTP. The handshake must still
// produce a real cipher suite, but no SRTP profile.
TEST_F(TransportTest, OnlyServerSendsSrtpXtn) {
  p1_->SetupSrtp();
  SetDtlsPeer();
  // This should connect, but with no SRTP extension negotiated.
  // The client side might negotiate a data channel only.
  ConnectSocket();
  ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite());
  ASSERT_EQ(0, p1_->srtpCipher());
}
1371 
// Only the client (p2_) is configured for SRTP. The handshake must still
// produce a real cipher suite, but no SRTP profile (checked on p1_).
TEST_F(TransportTest, OnlyClientSendsSrtpXtn) {
  p2_->SetupSrtp();
  SetDtlsPeer();
  // This should connect, but with no SRTP extension negotiated.
  // The server side might negotiate a data channel only.
  ConnectSocket();
  ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite());
  ASSERT_EQ(0, p1_->srtpCipher());
}
1381 
// Value-parameterized variant of the TransportTest fixture; each instance
// receives one SRTP cipher (uint16_t) as its GetParam() value.
class TransportSrtpParameterTest
    : public TransportTest,
      public ::testing::WithParamInterface<uint16_t> {};
1385 
// Runs each TransportSrtpParameterTest once per cipher returned by
// TransportLayerDtls::GetDefaultSrtpCiphers().
INSTANTIATE_TEST_CASE_P(
    SrtpParamInit, TransportSrtpParameterTest,
    ::testing::ValuesIn(TransportLayerDtls::GetDefaultSrtpCiphers()));
1389 
// Despite the name, this checks successful negotiation. p1_ (server) offers
// its default SRTP list via SetupSrtp(), while p2_ (client) offers only the
// parameter cipher. The connection must succeed with both sides reporting
// that cipher.
TEST_P(TransportSrtpParameterTest, TestSrtpCiphersMismatchCombinations) {
  uint16_t cipher = GetParam();
  std::cerr << "Checking cipher: " << cipher << std::endl;

  p1_->SetupSrtp();

  std::vector<uint16_t> setB;
  setB.push_back(cipher);

  p2_->SetSrtpCiphers(setB);
  SetDtlsPeer();
  ConnectSocket();

  ASSERT_EQ(cipher, p1_->srtpCipher());
  ASSERT_EQ(cipher, p2_->srtpCipher());
}
1406 
// NSS doesn't support DHE suites on the server end.
// This checks to see if we barf when that's the only option available:
// the client offers only TLS_DHE_RSA_WITH_AES_128_CBC_SHA, so the handshake
// must fail.
TEST_F(TransportTest, TestDheOnlyFails) {
  SetDtlsPeer();

  // p2_ is the client
  // setting this on p1_ (the server) causes NSS to assert
  ConfigureOneCipher(p2_, TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
  ConnectSocketExpectFail();
}
1417 
1418 }  // end namespace
1419