1 /*
2  *  Copyright (c) 2018-present, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <fizz/crypto/aead/Aead.h>
12 #include <fizz/protocol/Params.h>
13 #include <fizz/record/Types.h>
14 #include <folly/Optional.h>
15 #include <folly/io/IOBufQueue.h>
16 
17 namespace fizz {
18 
19 struct TLSContent {
20   Buf data;
21   ContentType contentType;
22   EncryptionLevel encryptionLevel;
23 };
24 
25 /**
26  * RecordLayerState contains the state of the record layer -- all data
27  * that is needed in order to decrypt/encrypt the _next_ record from the wire/
28  * from the application.
29  */
30 struct RecordLayerState {
31   folly::Optional<TrafficKey> key;
32   folly::Optional<uint64_t> sequence;
33 };
34 
35 class ReadRecordLayer {
36  public:
37   template <class T>
38   struct ReadResult {
ReadResultReadResult39     ReadResult() : message(folly::none), sizeHint{0} {}
ReadResultReadResult40     /* implicit */ ReadResult(const folly::None&)
41         : message(folly::none), sizeHint{0} {}
42 
43     folly::Optional<T> message;
44 
45     // A non-zero size hint indicates the amount of bytes needed to continue
46     // record processing.
47     size_t sizeHint{0};
48 
fromReadResult49     static ReadResult from(T&& t) {
50       return from(std::forward<T>(t), 0);
51     }
52 
fromReadResult53     static ReadResult from(T&& t, size_t sizeHint) {
54       ReadResult r;
55       r.message = std::forward<T>(t);
56       r.sizeHint = sizeHint;
57       return r;
58     }
59 
noneReadResult60     static ReadResult none() {
61       return noneWithSizeHint(0);
62     }
63 
noneWithSizeHintReadResult64     static ReadResult noneWithSizeHint(size_t s) {
65       ReadResult r;
66       r.sizeHint = s;
67       return r;
68     }
69 
70     operator bool() const {
71       return bool(message);
72     }
73 
has_valueReadResult74     [[nodiscard]] bool has_value() const {
75       return message.has_value();
76     }
77 
78     auto operator->() -> decltype(this->message.operator->()) {
79       return message.operator->();
80     }
81 
82     auto operator->() const -> decltype(this->message.operator->()) {
83       return message.operator->();
84     }
85     auto operator*() -> decltype(this->message.operator*()) {
86       return message.operator*();
87     }
88   };
89 
90   virtual ~ReadRecordLayer() = default;
91 
92   /**
93    * Reads a fragment from the record layer. Returns an empty optional if
94    * insuficient data available. Throws if data malformed. On success, advances
95    * buf the amount read.
96    */
97   virtual ReadResult<TLSMessage> read(
98       folly::IOBufQueue& buf,
99       Aead::AeadOptions options) = 0;
100 
101   /**
102    * Get a message from the record layer. Returns none if insufficient data was
103    * available on the socket. Throws on parse error.
104    */
105   virtual ReadResult<Param> readEvent(
106       folly::IOBufQueue& socketBuf,
107       Aead::AeadOptions options);
108 
109   /**
110    * Check if there is decrypted but unparsed handshake data buffered.
111    */
112   virtual bool hasUnparsedHandshakeData() const;
113 
114   /**
115    * Returns the current encryption level of the data that the read record layer
116    * can process.
117    */
118   virtual EncryptionLevel getEncryptionLevel() const = 0;
119 
120   /**
121    * Returns a snapshot of the state of the record layer.
122    *
123    * `key`, if set, indicates the keying parameters for the AEAD associated
124    * with this ReadRecordLayer.
125    *
126    * `sequence`, if set, indicates the sequence number of the next expected
127    * record to be read.
128    */
getRecordLayerState()129   virtual RecordLayerState getRecordLayerState() const {
130     return RecordLayerState{};
131   }
132 
133   static folly::Optional<Param> decodeHandshakeMessage(folly::IOBufQueue& buf);
134 
135  private:
136   folly::IOBufQueue unparsedHandshakeData_{
137       folly::IOBufQueue::cacheChainLength()};
138 };
139 
140 class WriteRecordLayer {
141  public:
142   virtual ~WriteRecordLayer() = default;
143 
144   virtual TLSContent write(TLSMessage&& msg, Aead::AeadOptions options)
145       const = 0;
146 
writeAlert(Alert && alert)147   TLSContent writeAlert(Alert&& alert) const {
148     return write(
149         TLSMessage{ContentType::alert, encode(std::move(alert))},
150         Aead::AeadOptions());
151   }
152 
writeAppData(std::unique_ptr<folly::IOBuf> && appData,Aead::AeadOptions options)153   TLSContent writeAppData(
154       std::unique_ptr<folly::IOBuf>&& appData,
155       Aead::AeadOptions options) const {
156     return write(
157         TLSMessage{ContentType::application_data, std::move(appData)}, options);
158   }
159 
160   template <typename... Args>
writeHandshake(Buf && encodedHandshakeMsg,Args &&...args)161   TLSContent writeHandshake(Buf&& encodedHandshakeMsg, Args&&... args) const {
162     TLSMessage msg{ContentType::handshake, std::move(encodedHandshakeMsg)};
163     addMessage(msg.fragment, std::forward<Args>(args)...);
164     return write(std::move(msg), Aead::AeadOptions());
165   }
166 
setProtocolVersion(ProtocolVersion version)167   void setProtocolVersion(ProtocolVersion version) const {
168     auto realVersion = getRealDraftVersion(version);
169     if (realVersion == ProtocolVersion::tls_1_3_23) {
170       useAdditionalData_ = false;
171     } else {
172       useAdditionalData_ = true;
173     }
174   }
175 
176   /**
177    * Returns the current encryption level of the data that the write record
178    * layer writes at.
179    */
180   virtual EncryptionLevel getEncryptionLevel() const = 0;
181 
182   /**
183    * Returns a snapshot of the state of the record layer.
184    *
185    * `key`, if set, indicates the keying parameters for the AEAD associated
186    * with this WriteRecordLayer.
187    *
188    * `sequence`, if set, indicates the sequence number of the next expected
189    * record to be written.
190    */
getRecordLayerState()191   virtual RecordLayerState getRecordLayerState() const {
192     return RecordLayerState{};
193   }
194 
195  protected:
196   mutable bool useAdditionalData_{true};
197 
198  private:
199   template <typename... Args>
addMessage(Buf & buf,Buf && add,Args &&...args)200   static void addMessage(Buf& buf, Buf&& add, Args&&... args) {
201     buf->prependChain(std::move(add));
202     addMessage(buf, std::forward<Args>(args)...);
203   }
204 
addMessage(Buf &)205   static void addMessage(Buf& /*buf*/) {}
206 };
207 } // namespace fizz
208