1 use crate::kem::Kem as KemTrait;
2 
3 use byteorder::{BigEndian, ByteOrder};
4 use digest::{BlockInput, Digest, FixedOutput, Reset, Update};
5 use generic_array::GenericArray;
6 use sha2::{Sha256, Sha384, Sha512};
7 
/// Version label (RFC 9180 §4) that is prepended to the IKM of every labeled extract and to the
/// info string of every labeled expand performed in this module.
const VERSION_LABEL: &[u8] = b"HPKE-v1";

// This is currently the maximum value of Nh, achieved by HKDF-SHA512 (RFC 9180 §7.2).
// Nh is measured in bytes: SHA-512 produces 512 bits, i.e. 64 bytes.
pub(crate) const MAX_DIGEST_SIZE: usize = 64;
12 
// Almost all of the KDF functionality is provided by the hkdf crate; this module only adds HPKE's labeling on top of it.
14 
15 /// Represents key derivation functionality
16 pub trait Kdf {
17     /// The underlying hash function
18     #[doc(hidden)]
19     type HashImpl: Digest + Update + BlockInput + FixedOutput + Reset + Default + Clone;
20 
21     /// The algorithm identifier for a KDF implementation
22     const KDF_ID: u16;
23 }
24 
// Functions below use `Kdf` as the name of a generic type parameter, which shadows the trait; this alias lets them still name the trait unambiguously.
26 use Kdf as KdfTrait;
27 
28 /// The implementation of HKDF-SHA256
29 pub struct HkdfSha256 {}
30 
31 // The KDF_ID constant below come from §7.2
32 
33 impl KdfTrait for HkdfSha256 {
34     #[doc(hidden)]
35     type HashImpl = Sha256;
36 
37     const KDF_ID: u16 = 0x0001;
38 }
39 
40 /// The implementation of HKDF-SHA384
41 pub struct HkdfSha384 {}
42 
43 impl KdfTrait for HkdfSha384 {
44     #[doc(hidden)]
45     type HashImpl = Sha384;
46 
47     const KDF_ID: u16 = 0x0002;
48 }
49 
50 /// The implementation of HKDF-SHA512
51 pub struct HkdfSha512 {}
52 
53 impl KdfTrait for HkdfSha512 {
54     #[doc(hidden)]
55     type HashImpl = Sha512;
56 
57     const KDF_ID: u16 = 0x0003;
58 }
59 
60 // def ExtractAndExpand(dh, kemContext):
61 //   eae_prk = LabeledExtract(zero(0), "eae_prk", dh)
62 //   shared_secret = LabeledExpand(eae_prk, "shared_secret", kemContext, Nsecret)
63 //   return shared_secret
64 /// Uses the given IKM to extract a secret, and then uses that secret, plus the given suite ID and
65 /// info string, to expand to the output buffer
extract_and_expand<Kem: KemTrait>( ikm: &[u8], suite_id: &[u8], info: &[u8], out: &mut [u8], ) -> Result<(), hkdf::InvalidLength>66 pub(crate) fn extract_and_expand<Kem: KemTrait>(
67     ikm: &[u8],
68     suite_id: &[u8],
69     info: &[u8],
70     out: &mut [u8],
71 ) -> Result<(), hkdf::InvalidLength> {
72     // Construct the labels
73     // Extract using given IKM
74     let (_, hkdf_ctx) = labeled_extract::<Kem::Kdf>(&[], suite_id, b"eae_prk", ikm);
75     // Expand using given info string
76     hkdf_ctx.labeled_expand(suite_id, b"shared_secret", info, out)
77 }
78 
79 // def LabeledExtract(salt, label, ikm):
80 //   labeled_ikm = concat("HPKE-05 ", suite_id, label, ikm)
81 //   return Extract(salt, labeled_ikm)
82 /// Returns the HKDF context derived from `(salt=salt, ikm="HPKE-05 "||suite_id||label||ikm)`
labeled_extract<Kdf: KdfTrait>( salt: &[u8], suite_id: &[u8], label: &[u8], ikm: &[u8], ) -> ( GenericArray<u8, <<Kdf as KdfTrait>::HashImpl as FixedOutput>::OutputSize>, hkdf::Hkdf<Kdf::HashImpl>, )83 pub(crate) fn labeled_extract<Kdf: KdfTrait>(
84     salt: &[u8],
85     suite_id: &[u8],
86     label: &[u8],
87     ikm: &[u8],
88 ) -> (
89     GenericArray<u8, <<Kdf as KdfTrait>::HashImpl as FixedOutput>::OutputSize>,
90     hkdf::Hkdf<Kdf::HashImpl>,
91 ) {
92     // Call HKDF-Extract with the IKM being the concatenation of all of the above
93     let mut extract_ctx = hkdf::HkdfExtract::<Kdf::HashImpl>::new(Some(&salt));
94     extract_ctx.input_ikm(VERSION_LABEL);
95     extract_ctx.input_ikm(suite_id);
96     extract_ctx.input_ikm(label);
97     extract_ctx.input_ikm(ikm);
98     extract_ctx.finalize()
99 }
100 
101 // This trait only exists so I can implement it for hkdf::Hkdf
102 pub(crate) trait LabeledExpand {
labeled_expand( &self, suite_id: &[u8], label: &[u8], info: &[u8], out: &mut [u8], ) -> Result<(), hkdf::InvalidLength>103     fn labeled_expand(
104         &self,
105         suite_id: &[u8],
106         label: &[u8],
107         info: &[u8],
108         out: &mut [u8],
109     ) -> Result<(), hkdf::InvalidLength>;
110 }
111 
112 impl<D: Update + BlockInput + FixedOutput + Reset + Default + Clone> LabeledExpand
113     for hkdf::Hkdf<D>
114 {
115     // def LabeledExpand(prk, label, info, L):
116     //   labeled_info = concat(I2OSP(L, 2), "HPKE-05 ", suite_id, label, info)
117     //   return Expand(prk, labeled_info, L)
labeled_expand( &self, suite_id: &[u8], label: &[u8], info: &[u8], out: &mut [u8], ) -> Result<(), hkdf::InvalidLength>118     fn labeled_expand(
119         &self,
120         suite_id: &[u8],
121         label: &[u8],
122         info: &[u8],
123         out: &mut [u8],
124     ) -> Result<(), hkdf::InvalidLength> {
125         // We need to write the length as a u16, so that's the de-facto upper bound on length
126         assert!(out.len() <= u16::MAX as usize);
127 
128         // Encode the output length in the info string
129         let mut len_buf = [0u8; 2];
130         BigEndian::write_u16(&mut len_buf, out.len() as u16);
131 
132         // Call HKDF-Expand() with the info string set to the concatenation of all of the above
133         let labeled_info = [&len_buf, VERSION_LABEL, suite_id, label, info];
134         self.expand_multi_info(&labeled_info, out)
135     }
136 }
137