1 /* Copyright 2015-2016 OpenMarket Ltd
2  *
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include "olm/message.hh"
16 
17 #include "olm/memory.hh"
18 
19 namespace {
20 
/* Number of bytes needed to encode `value` as a base-128 varint (7 payload
 * bits per byte). A value of zero still occupies one byte. */
template<typename T>
static std::size_t varint_length(
    T value
) {
    std::size_t bytes = 1;
    for (; value >= 128U; value >>= 7) {
        bytes += 1;
    }
    return bytes;
}
32 
33 
/* Writes `value` to `output` as a base-128 varint, least significant 7-bit
 * group first. Every byte except the last has its continuation bit (0x80)
 * set. Returns a pointer one past the last byte written; the caller must
 * have room for varint_length(value) bytes. */
template<typename T>
static std::uint8_t * varint_encode(
    std::uint8_t * output,
    T value
) {
    std::uint8_t * cursor = output;
    for (; value >= 128U; value >>= 7) {
        *cursor = static_cast<std::uint8_t>((0x7F & value) | 0x80);
        ++cursor;
    }
    *cursor = static_cast<std::uint8_t>(value);
    ++cursor;
    return cursor;
}
46 
47 
/* Decodes the base-128 varint occupying [varint_start, varint_end).
 * Bytes are folded from the last one (most significant group) back to the
 * first, with continuation bits masked off. For the unsigned types used in
 * this file, groups that do not fit in T are shifted out silently. An empty
 * range decodes to 0. */
template<typename T>
static T varint_decode(
    std::uint8_t const * varint_start,
    std::uint8_t const * varint_end
) {
    T result = 0;
    for (std::uint8_t const * p = varint_end; p != varint_start; ) {
        --p;
        result <<= 7;
        result |= 0x7F & *p;
    }
    return result;
}
63 
64 
/* Advances past one varint starting at `input`. Returns the position just
 * after the first byte whose continuation bit (0x80) is clear, or
 * `input_end` if the varint runs off the end of the buffer. */
static std::uint8_t const * varint_skip(
    std::uint8_t const * input,
    std::uint8_t const * input_end
) {
    std::uint8_t const * cursor = input;
    bool more = true;
    while (more && cursor != input_end) {
        more = (*cursor & 0x80) != 0;
        ++cursor;
    }
    return cursor;
}
77 
78 
varstring_length(std::size_t string_length)79 static std::size_t varstring_length(
80     std::size_t string_length
81 ) {
82     return varint_length(string_length) + string_length;
83 }
84 
/* Layout constants for the ratchet message. Each tag byte is built as
 * (field_number << 3) | wire_type (the protobuf key layout). Wire type 0 is
 * a varint and wire type 2 is a length-prefixed byte string; these are the
 * two kinds skip_unknown knows how to skip. */
static std::size_t const VERSION_LENGTH = 1;
static std::uint8_t const RATCHET_KEY_TAG = (1 << 3) | 2;  /* 012 */
static std::uint8_t const COUNTER_TAG = (2 << 3) | 0;      /* 020 */
static std::uint8_t const CIPHERTEXT_TAG = (4 << 3) | 2;   /* 042 */
89 
encode(std::uint8_t * pos,std::uint8_t tag,std::uint32_t value)90 static std::uint8_t * encode(
91     std::uint8_t * pos,
92     std::uint8_t tag,
93     std::uint32_t value
94 ) {
95     *(pos++) = tag;
96     return varint_encode(pos, value);
97 }
98 
encode(std::uint8_t * pos,std::uint8_t tag,std::uint8_t * & value,std::size_t value_length)99 static std::uint8_t * encode(
100     std::uint8_t * pos,
101     std::uint8_t tag,
102     std::uint8_t * & value, std::size_t value_length
103 ) {
104     *(pos++) = tag;
105     pos = varint_encode(pos, value_length);
106     value = pos;
107     return pos + value_length;
108 }
109 
decode(std::uint8_t const * pos,std::uint8_t const * end,std::uint8_t tag,std::uint32_t & value,bool & has_value)110 static std::uint8_t const * decode(
111     std::uint8_t const * pos, std::uint8_t const * end,
112     std::uint8_t tag,
113     std::uint32_t & value, bool & has_value
114 ) {
115     if (pos != end && *pos == tag) {
116         ++pos;
117         std::uint8_t const * value_start = pos;
118         pos = varint_skip(pos, end);
119         value = varint_decode<std::uint32_t>(value_start, pos);
120         has_value = true;
121     }
122     return pos;
123 }
124 
125 
decode(std::uint8_t const * pos,std::uint8_t const * end,std::uint8_t tag,std::uint8_t const * & value,std::size_t & value_length)126 static std::uint8_t const * decode(
127     std::uint8_t const * pos, std::uint8_t const * end,
128     std::uint8_t tag,
129     std::uint8_t const * & value, std::size_t & value_length
130 ) {
131     if (pos != end && *pos == tag) {
132         ++pos;
133         std::uint8_t const * len_start = pos;
134         pos = varint_skip(pos, end);
135         std::size_t len = varint_decode<std::size_t>(len_start, pos);
136         if (len > std::size_t(end - pos)) return end;
137         value = pos;
138         value_length = len;
139         pos += len;
140     }
141     return pos;
142 }
143 
skip_unknown(std::uint8_t const * pos,std::uint8_t const * end)144 static std::uint8_t const * skip_unknown(
145     std::uint8_t const * pos, std::uint8_t const * end
146 ) {
147     if (pos != end) {
148         uint8_t tag = *pos;
149         if ((tag & 0x7) == 0) {
150             pos = varint_skip(pos, end);
151             pos = varint_skip(pos, end);
152         } else if ((tag & 0x7) == 2) {
153             pos = varint_skip(pos, end);
154             std::uint8_t const * len_start = pos;
155             pos = varint_skip(pos, end);
156             std::size_t len = varint_decode<std::size_t>(len_start, pos);
157             if (len > std::size_t(end - pos)) return end;
158             pos += len;
159         } else {
160             return end;
161         }
162     }
163     return pos;
164 }
165 
166 } // namespace
167 
168 
encode_message_length(std::uint32_t counter,std::size_t ratchet_key_length,std::size_t ciphertext_length,std::size_t mac_length)169 std::size_t olm::encode_message_length(
170     std::uint32_t counter,
171     std::size_t ratchet_key_length,
172     std::size_t ciphertext_length,
173     std::size_t mac_length
174 ) {
175     std::size_t length = VERSION_LENGTH;
176     length += 1 + varstring_length(ratchet_key_length);
177     length += 1 + varint_length(counter);
178     length += 1 + varstring_length(ciphertext_length);
179     length += mac_length;
180     return length;
181 }
182 
183 
encode_message(olm::MessageWriter & writer,std::uint8_t version,std::uint32_t counter,std::size_t ratchet_key_length,std::size_t ciphertext_length,std::uint8_t * output)184 void olm::encode_message(
185     olm::MessageWriter & writer,
186     std::uint8_t version,
187     std::uint32_t counter,
188     std::size_t ratchet_key_length,
189     std::size_t ciphertext_length,
190     std::uint8_t * output
191 ) {
192     std::uint8_t * pos = output;
193     *(pos++) = version;
194     pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length);
195     pos = encode(pos, COUNTER_TAG, counter);
196     pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length);
197 }
198 
199 
decode_message(olm::MessageReader & reader,std::uint8_t const * input,std::size_t input_length,std::size_t mac_length)200 void olm::decode_message(
201     olm::MessageReader & reader,
202     std::uint8_t const * input, std::size_t input_length,
203     std::size_t mac_length
204 ) {
205     std::uint8_t const * pos = input;
206     std::uint8_t const * end = input + input_length - mac_length;
207     std::uint8_t const * unknown = nullptr;
208 
209     reader.version = 0;
210     reader.has_counter = false;
211     reader.counter = 0;
212     reader.input = input;
213     reader.input_length = input_length;
214     reader.ratchet_key = nullptr;
215     reader.ratchet_key_length = 0;
216     reader.ciphertext = nullptr;
217     reader.ciphertext_length = 0;
218 
219     if (input_length < mac_length) return;
220 
221     if (pos == end) return;
222     reader.version = *(pos++);
223 
224     while (pos != end) {
225         unknown = pos;
226         pos = decode(
227             pos, end, RATCHET_KEY_TAG,
228             reader.ratchet_key, reader.ratchet_key_length
229         );
230         pos = decode(
231             pos, end, COUNTER_TAG,
232             reader.counter, reader.has_counter
233         );
234         pos = decode(
235             pos, end, CIPHERTEXT_TAG,
236             reader.ciphertext, reader.ciphertext_length
237         );
238         if (unknown == pos) {
239             pos = skip_unknown(pos, end);
240         }
241     }
242 }
243 
244 
namespace {

/* Tag bytes for the pre-key message fields, built as
 * (field_number << 3) | 2. Every field is a length-prefixed byte string
 * (wire type 2). */
static std::uint8_t const ONE_TIME_KEY_ID_TAG = (1 << 3) | 2;  /* 012 */
static std::uint8_t const BASE_KEY_TAG = (2 << 3) | 2;         /* 022 */
static std::uint8_t const IDENTITY_KEY_TAG = (3 << 3) | 2;     /* 032 */
static std::uint8_t const MESSAGE_TAG = (4 << 3) | 2;          /* 042 */

} // namespace
253 
254 
encode_one_time_key_message_length(std::size_t one_time_key_length,std::size_t identity_key_length,std::size_t base_key_length,std::size_t message_length)255 std::size_t olm::encode_one_time_key_message_length(
256     std::size_t one_time_key_length,
257     std::size_t identity_key_length,
258     std::size_t base_key_length,
259     std::size_t message_length
260 ) {
261     std::size_t length = VERSION_LENGTH;
262     length += 1 + varstring_length(one_time_key_length);
263     length += 1 + varstring_length(identity_key_length);
264     length += 1 + varstring_length(base_key_length);
265     length += 1 + varstring_length(message_length);
266     return length;
267 }
268 
269 
encode_one_time_key_message(olm::PreKeyMessageWriter & writer,std::uint8_t version,std::size_t identity_key_length,std::size_t base_key_length,std::size_t one_time_key_length,std::size_t message_length,std::uint8_t * output)270 void olm::encode_one_time_key_message(
271     olm::PreKeyMessageWriter & writer,
272     std::uint8_t version,
273     std::size_t identity_key_length,
274     std::size_t base_key_length,
275     std::size_t one_time_key_length,
276     std::size_t message_length,
277     std::uint8_t * output
278 ) {
279     std::uint8_t * pos = output;
280     *(pos++) = version;
281     pos = encode(pos, ONE_TIME_KEY_ID_TAG, writer.one_time_key, one_time_key_length);
282     pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length);
283     pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length);
284     pos = encode(pos, MESSAGE_TAG, writer.message, message_length);
285 }
286 
287 
decode_one_time_key_message(PreKeyMessageReader & reader,std::uint8_t const * input,std::size_t input_length)288 void olm::decode_one_time_key_message(
289     PreKeyMessageReader & reader,
290     std::uint8_t const * input, std::size_t input_length
291 ) {
292     std::uint8_t const * pos = input;
293     std::uint8_t const * end = input + input_length;
294     std::uint8_t const * unknown = nullptr;
295 
296     reader.version = 0;
297     reader.one_time_key = nullptr;
298     reader.one_time_key_length = 0;
299     reader.identity_key = nullptr;
300     reader.identity_key_length = 0;
301     reader.base_key = nullptr;
302     reader.base_key_length = 0;
303     reader.message = nullptr;
304     reader.message_length = 0;
305 
306     if (pos == end) return;
307     reader.version = *(pos++);
308 
309     while (pos != end) {
310         unknown = pos;
311         pos = decode(
312             pos, end, ONE_TIME_KEY_ID_TAG,
313             reader.one_time_key, reader.one_time_key_length
314         );
315         pos = decode(
316             pos, end, BASE_KEY_TAG,
317             reader.base_key, reader.base_key_length
318         );
319         pos = decode(
320             pos, end, IDENTITY_KEY_TAG,
321             reader.identity_key, reader.identity_key_length
322         );
323         pos = decode(
324             pos, end, MESSAGE_TAG,
325             reader.message, reader.message_length
326         );
327         if (unknown == pos) {
328             pos = skip_unknown(pos, end);
329         }
330     }
331 }
332 
333 
334 
/* Tag bytes for the group message fields, built as
 * (field_number << 3) | wire_type. The message index is a varint (field 1,
 * wire type 0). The ciphertext is a length-prefixed byte string (field 2,
 * wire type 2). */
static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = (1 << 3) | 0;  /* 010 */
static const std::uint8_t GROUP_CIPHERTEXT_TAG = (2 << 3) | 2;     /* 022 */
337 
_olm_encode_group_message_length(uint32_t message_index,size_t ciphertext_length,size_t mac_length,size_t signature_length)338 size_t _olm_encode_group_message_length(
339     uint32_t message_index,
340     size_t ciphertext_length,
341     size_t mac_length,
342     size_t signature_length
343 ) {
344     size_t length = VERSION_LENGTH;
345     length += 1 + varint_length(message_index);
346     length += 1 + varstring_length(ciphertext_length);
347     length += mac_length;
348     length += signature_length;
349     return length;
350 }
351 
352 
_olm_encode_group_message(uint8_t version,uint32_t message_index,size_t ciphertext_length,uint8_t * output,uint8_t ** ciphertext_ptr)353 size_t _olm_encode_group_message(
354     uint8_t version,
355     uint32_t message_index,
356     size_t ciphertext_length,
357     uint8_t *output,
358     uint8_t **ciphertext_ptr
359 ) {
360     std::uint8_t * pos = output;
361 
362     *(pos++) = version;
363     pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index);
364     pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
365     return pos-output;
366 }
367 
_olm_decode_group_message(const uint8_t * input,size_t input_length,size_t mac_length,size_t signature_length,struct _OlmDecodeGroupMessageResults * results)368 void _olm_decode_group_message(
369     const uint8_t *input, size_t input_length,
370     size_t mac_length, size_t signature_length,
371     struct _OlmDecodeGroupMessageResults *results
372 ) {
373     std::uint8_t const * pos = input;
374     std::size_t trailer_length = mac_length + signature_length;
375     std::uint8_t const * end = input + input_length - trailer_length;
376     std::uint8_t const * unknown = nullptr;
377 
378     bool has_message_index = false;
379     results->version = 0;
380     results->message_index = 0;
381     results->has_message_index = (int)has_message_index;
382     results->ciphertext = nullptr;
383     results->ciphertext_length = 0;
384 
385     if (input_length < trailer_length) return;
386 
387     if (pos == end) return;
388     results->version = *(pos++);
389 
390     while (pos != end) {
391         unknown = pos;
392         pos = decode(
393             pos, end, GROUP_MESSAGE_INDEX_TAG,
394             results->message_index, has_message_index
395         );
396         pos = decode(
397             pos, end, GROUP_CIPHERTEXT_TAG,
398             results->ciphertext, results->ciphertext_length
399         );
400         if (unknown == pos) {
401             pos = skip_unknown(pos, end);
402         }
403     }
404 
405     results->has_message_index = (int)has_message_index;
406 }
407