1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 // Original author: ekr@rtfm.com
8 
9 #ifndef transportlayerdtls_h__
10 #define transportlayerdtls_h__
11 
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "sigslot.h"

#include "mozilla/RefPtr.h"
#include "mozilla/UniquePtr.h"
#include "nsCOMPtr.h"
#include "nsIEventTarget.h"
#include "nsITimer.h"
#include "ScopedNSSTypes.h"
#include "m_cpp_utils.h"
#include "dtlsidentity.h"
#include "transportflow.h"
#include "transportlayer.h"
27 
28 namespace mozilla {
29 
30 struct Packet;
31 
32 class TransportLayerNSPRAdapter {
33  public:
TransportLayerNSPRAdapter(TransportLayer * output)34   explicit TransportLayerNSPRAdapter(TransportLayer *output) :
35   output_(output),
36   input_(),
37   enabled_(true) {}
38 
39   void PacketReceived(const void *data, int32_t len);
40   int32_t Recv(void *buf, int32_t buflen);
41   int32_t Write(const void *buf, int32_t length);
SetEnabled(bool enabled)42   void SetEnabled(bool enabled) { enabled_ = enabled; }
43 
44  private:
45   DISALLOW_COPY_ASSIGN(TransportLayerNSPRAdapter);
46 
47   TransportLayer *output_;
48   std::queue<Packet *> input_;
49   bool enabled_;
50 };
51 
52 class TransportLayerDtls final : public TransportLayer {
53  public:
TransportLayerDtls()54   TransportLayerDtls() :
55       role_(CLIENT),
56       verification_mode_(VERIFY_UNSET),
57       ssl_fd_(nullptr),
58       auth_hook_called_(false),
59       cert_ok_(false) {}
60 
61   virtual ~TransportLayerDtls();
62 
63   enum Role { CLIENT, SERVER};
64   enum Verification { VERIFY_UNSET, VERIFY_ALLOW_ALL, VERIFY_DIGEST};
65   const static size_t kMaxDigestLength = HASH_LENGTH_MAX;
66 
67   // DTLS-specific operations
SetRole(Role role)68   void SetRole(Role role) { role_ = role;}
role()69   Role role() { return role_; }
70 
SetIdentity(const RefPtr<DtlsIdentity> & identity)71   void SetIdentity(const RefPtr<DtlsIdentity>& identity) {
72     identity_ = identity;
73   }
74   nsresult SetAlpn(const std::set<std::string>& allowedAlpn,
75                    const std::string& alpnDefault);
GetNegotiatedAlpn()76   const std::string& GetNegotiatedAlpn() const { return alpn_; }
77 
78   nsresult SetVerificationAllowAll();
79   nsresult SetVerificationDigest(const std::string digest_algorithm,
80                                  const unsigned char *digest_value,
81                                  size_t digest_len);
82 
83   nsresult GetCipherSuite(uint16_t* cipherSuite) const;
84 
85   nsresult SetSrtpCiphers(std::vector<uint16_t> ciphers);
86   nsresult GetSrtpCipher(uint16_t *cipher) const;
87 
88   nsresult ExportKeyingMaterial(const std::string& label,
89                                 bool use_context,
90                                 const std::string& context,
91                                 unsigned char *out,
92                                 unsigned int outlen);
93 
94   // Transport layer overrides.
95   virtual nsresult InitInternal();
96   virtual void WasInserted();
97   virtual TransportResult SendPacket(const unsigned char *data, size_t len);
98 
99   // Signals
100   void StateChange(TransportLayer *layer, State state);
101   void PacketReceived(TransportLayer* layer, const unsigned char *data,
102                       size_t len);
103 
104   // For testing use only.  Returns the fd.
internal_fd()105   PRFileDesc* internal_fd() { CheckThread(); return ssl_fd_.get(); }
106 
107   TRANSPORT_LAYER_ID("dtls")
108 
109   private:
110   DISALLOW_COPY_ASSIGN(TransportLayerDtls);
111 
112   // A single digest to check
113   class VerificationDigest {
114    public:
VerificationDigest(std::string algorithm,const unsigned char * value,size_t len)115     VerificationDigest(std::string algorithm,
116                        const unsigned char *value, size_t len) {
117       MOZ_ASSERT(len <= sizeof(value_));
118 
119       algorithm_ = algorithm;
120       memcpy(value_, value, len);
121       len_ = len;
122     }
123 
124     NS_INLINE_DECL_THREADSAFE_REFCOUNTING(VerificationDigest)
125 
126     std::string algorithm_;
127     size_t len_;
128     unsigned char value_[kMaxDigestLength];
129 
130    private:
~VerificationDigest()131     ~VerificationDigest() {}
132     DISALLOW_COPY_ASSIGN(VerificationDigest);
133   };
134 
135 
136   bool Setup();
137   bool SetupCipherSuites(UniquePRFileDesc& ssl_fd) const;
138   bool SetupAlpn(UniquePRFileDesc& ssl_fd) const;
139   void Handshake();
140 
141   bool CheckAlpn();
142 
143   static SECStatus GetClientAuthDataHook(void *arg, PRFileDesc *fd,
144                                          CERTDistNames *caNames,
145                                          CERTCertificate **pRetCert,
146                                          SECKEYPrivateKey **pRetKey);
147   static SECStatus AuthCertificateHook(void *arg,
148                                        PRFileDesc *fd,
149                                        PRBool checksig,
150                                        PRBool isServer);
151   SECStatus AuthCertificateHook(PRFileDesc *fd,
152                                 PRBool checksig,
153                                 PRBool isServer);
154 
155   static void TimerCallback(nsITimer *timer, void *arg);
156 
157   SECStatus CheckDigest(const RefPtr<VerificationDigest>& digest,
158                         UniqueCERTCertificate& cert) const;
159 
160   RefPtr<DtlsIdentity> identity_;
161   // What ALPN identifiers are permitted.
162   std::set<std::string> alpn_allowed_;
163   // What ALPN identifier is used if ALPN is not supported.
164   // The empty string indicates that ALPN is required.
165   std::string alpn_default_;
166   // What ALPN string was negotiated.
167   std::string alpn_;
168   std::vector<uint16_t> srtp_ciphers_;
169 
170   Role role_;
171   Verification verification_mode_;
172   std::vector<RefPtr<VerificationDigest> > digests_;
173 
174   // Must delete nspr_io_adapter after ssl_fd_ b/c ssl_fd_ causes an alert
175   // (ssl_fd_ contains an un-owning pointer to nspr_io_adapter_)
176   UniquePtr<TransportLayerNSPRAdapter> nspr_io_adapter_;
177   UniquePRFileDesc ssl_fd_;
178 
179   nsCOMPtr<nsITimer> timer_;
180   bool auth_hook_called_;
181   bool cert_ok_;
182 };
183 
184 
185 }  // close namespace
186 #endif
187