1 /*
2  *  Copyright (c) 2018-present, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <folly/Optional.h>
12 #include <folly/io/IOBuf.h>
13 
14 namespace fizz {
15 
16 struct TrafficKey {
17   std::unique_ptr<folly::IOBuf> key;
18   std::unique_ptr<folly::IOBuf> iv;
19 
cloneTrafficKey20   TrafficKey clone() const {
21     return TrafficKey{key->clone(), iv->clone()};
22   }
23 };
24 
25 /**
26  * Interface for aead algorithms (RFC 5116).
27  */
28 class Aead {
29  public:
30   enum class BufferOption {
31     RespectSharedPolicy, // Assume shared = no in-place
32     AllowInPlace, // Assume in-place editing is safe
33     AllowFullModification, // Assume in-place editing and growing into
34                            // head/tailroom are safe.
35   };
36 
37   enum class AllocationOption {
38     Allow, // Allow allocating new buffers
39     Deny, // Disallow allocating new buffers
40   };
41 
42   struct AeadOptions {
43     BufferOption bufferOpt = BufferOption::RespectSharedPolicy;
44     AllocationOption allocOpt = AllocationOption::Allow;
45   };
46 
47   virtual ~Aead() = default;
48 
49   /**
50    * Returns the number of key bytes needed by this aead.
51    */
52   virtual size_t keyLength() const = 0;
53 
54   /**
55    * Returns the number of iv bytes needed by this aead.
56    */
57   virtual size_t ivLength() const = 0;
58 
59   /**
60    * Sets the key and iv for this aead. The length of the key and iv must match
61    * keyLength() and ivLength().
62    */
63   virtual void setKey(TrafficKey key) = 0;
64 
65   /**
66    * Retrieves a shallow copy (IOBuf cloned) version of the TrafficKey
67    * corresponding to this AEAD, if set. Otherwise, returns none.
68    */
69   virtual folly::Optional<TrafficKey> getKey() const = 0;
70 
71   /**
72    * Encrypts plaintext. Will throw on error.
73    *
74    * Uses BufferOption::RespectSharedPolicy and AllocationOption::Allow by
75    * default.
76    */
encrypt(std::unique_ptr<folly::IOBuf> && plaintext,const folly::IOBuf * associatedData,uint64_t seqNum)77   std::unique_ptr<folly::IOBuf> encrypt(
78       std::unique_ptr<folly::IOBuf>&& plaintext,
79       const folly::IOBuf* associatedData,
80       uint64_t seqNum) const {
81     return encrypt(
82         std::forward<std::unique_ptr<folly::IOBuf>>(plaintext),
83         associatedData,
84         seqNum,
85         {BufferOption::RespectSharedPolicy, AllocationOption::Allow});
86   }
87 
88   virtual std::unique_ptr<folly::IOBuf> encrypt(
89       std::unique_ptr<folly::IOBuf>&& plaintext,
90       const folly::IOBuf* associatedData,
91       uint64_t seqNum,
92       AeadOptions options) const = 0;
93 
94   /**
95    * Version of encrypt which is guaranteed to be inplace. Will throw an
96    * exception if the inplace encryption cannot be done.
97    *
98    * Equivalent of calling encrypt() with BufferOption::AllowFullModification
99    * and AllocationOption::Deny.
100    */
101   virtual std::unique_ptr<folly::IOBuf> inplaceEncrypt(
102       std::unique_ptr<folly::IOBuf>&& plaintext,
103       const folly::IOBuf* associatedData,
104       uint64_t seqNum) const = 0;
105 
106   /**
107    * Set a hint to the AEAD about how much space to try to leave as headroom for
108    * ciphertexts returned from encrypt.  Implementations may or may not honor
109    * this.
110    */
111   virtual void setEncryptedBufferHeadroom(size_t headroom) = 0;
112 
113   /**
114    * Decrypt ciphertext. Will throw if the ciphertext does not decrypt
115    * successfully.
116    *
117    * Uses BufferOption::RespectSharedPolicy and AllocationOption::Allow by
118    * default.
119    */
decrypt(std::unique_ptr<folly::IOBuf> && ciphertext,const folly::IOBuf * associatedData,uint64_t seqNum)120   std::unique_ptr<folly::IOBuf> decrypt(
121       std::unique_ptr<folly::IOBuf>&& ciphertext,
122       const folly::IOBuf* associatedData,
123       uint64_t seqNum) const {
124     return decrypt(
125         std::forward<std::unique_ptr<folly::IOBuf>>(ciphertext),
126         associatedData,
127         seqNum,
128         {BufferOption::RespectSharedPolicy, AllocationOption::Allow});
129   }
130 
decrypt(std::unique_ptr<folly::IOBuf> && ciphertext,const folly::IOBuf * associatedData,uint64_t seqNum,AeadOptions options)131   virtual std::unique_ptr<folly::IOBuf> decrypt(
132       std::unique_ptr<folly::IOBuf>&& ciphertext,
133       const folly::IOBuf* associatedData,
134       uint64_t seqNum,
135       AeadOptions options) const {
136     auto plaintext = tryDecrypt(
137         std::forward<std::unique_ptr<folly::IOBuf>>(ciphertext),
138         associatedData,
139         seqNum,
140         options);
141     if (!plaintext) {
142       throw std::runtime_error("decryption failed");
143     }
144     return std::move(*plaintext);
145   }
146 
147   /**
148    * Decrypt ciphertext. Will return none if the ciphertext does not decrypt
149    * successfully. May still throw from errors unrelated to ciphertext.
150    *
151    * Uses BufferOption::RespectSharedPolicy and AllocationOption::Allow by
152    * default.
153    */
tryDecrypt(std::unique_ptr<folly::IOBuf> && ciphertext,const folly::IOBuf * associatedData,uint64_t seqNum)154   folly::Optional<std::unique_ptr<folly::IOBuf>> tryDecrypt(
155       std::unique_ptr<folly::IOBuf>&& ciphertext,
156       const folly::IOBuf* associatedData,
157       uint64_t seqNum) const {
158     return tryDecrypt(
159         std::forward<std::unique_ptr<folly::IOBuf>>(ciphertext),
160         associatedData,
161         seqNum,
162         {BufferOption::RespectSharedPolicy, AllocationOption::Allow});
163   }
164 
165   virtual folly::Optional<std::unique_ptr<folly::IOBuf>> tryDecrypt(
166       std::unique_ptr<folly::IOBuf>&& ciphertext,
167       const folly::IOBuf* associatedData,
168       uint64_t seqNum,
169       AeadOptions options) const = 0;
170 
171   /**
172    * Returns the number of bytes the aead will add to the plaintext (size of
173    * ciphertext - size of plaintext).
174    */
175   virtual size_t getCipherOverhead() const = 0;
176 };
177 } // namespace fizz
178