1 /*
2  * crypto.c - Manage the global crypto
3  *
4  * Copyright (C) 2013 - 2019, Max Lv <max.c.lv@gmail.com>
5  *
6  * This file is part of the shadowsocks-libev.
7  *
8  * shadowsocks-libev is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * shadowsocks-libev is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with shadowsocks-libev; see the file COPYING. If not, see
20  * <http://www.gnu.org/licenses/>.
21  */
22 
23 #ifdef HAVE_CONFIG_H
24 #include "config.h"
25 #endif
26 
27 #if defined(__linux__) && defined(HAVE_LINUX_RANDOM_H)
28 #include <fcntl.h>
29 #include <unistd.h>
30 #include <sys/ioctl.h>
31 #include <linux/random.h>
32 #endif
33 
34 #include <stdint.h>
35 #include <sodium.h>
36 #include <mbedtls/version.h>
37 #include <mbedtls/md5.h>
38 
39 #include "base64.h"
40 #include "crypto.h"
41 #include "stream.h"
42 #include "aead.h"
43 #include "utils.h"
44 #include "ppbloom.h"
45 
46 int
balloc(buffer_t * ptr,size_t capacity)47 balloc(buffer_t *ptr, size_t capacity)
48 {
49     sodium_memzero(ptr, sizeof(buffer_t));
50     ptr->data     = ss_malloc(capacity);
51     ptr->capacity = capacity;
52     return capacity;
53 }
54 
55 int
brealloc(buffer_t * ptr,size_t len,size_t capacity)56 brealloc(buffer_t *ptr, size_t len, size_t capacity)
57 {
58     if (ptr == NULL)
59         return -1;
60     size_t real_capacity = max(len, capacity);
61     if (ptr->capacity < real_capacity) {
62         ptr->data     = ss_realloc(ptr->data, real_capacity);
63         ptr->capacity = real_capacity;
64     }
65     return real_capacity;
66 }
67 
68 void
bfree(buffer_t * ptr)69 bfree(buffer_t *ptr)
70 {
71     if (ptr == NULL)
72         return;
73     ptr->idx      = 0;
74     ptr->len      = 0;
75     ptr->capacity = 0;
76     if (ptr->data != NULL) {
77         ss_free(ptr->data);
78     }
79 }
80 
81 int
bprepend(buffer_t * dst,buffer_t * src,size_t capacity)82 bprepend(buffer_t *dst, buffer_t *src, size_t capacity)
83 {
84     brealloc(dst, dst->len + src->len, capacity);
85     memmove(dst->data + src->len, dst->data, dst->len);
86     memcpy(dst->data, src->data, src->len);
87     dst->len = dst->len + src->len;
88     return dst->len;
89 }
90 
/*
 * Fill `output` with `len` cryptographically secure random bytes using
 * libsodium's randombytes_buf().
 *
 * Returns 0 on success. Returns -1 without writing anything when len is
 * negative: randombytes_buf() takes a size_t, so a negative int would be
 * converted to an enormous length and overrun `output`.
 */
int
rand_bytes(void *output, int len)
{
    if (len < 0) {
        return -1;
    }

    randombytes_buf(output, (size_t)len);
    return 0;
}
98 
/*
 * Compute the MD5 digest of the n bytes at d.
 *
 * md: caller-supplied output buffer of at least 16 bytes. When NULL, a
 *     function-level static buffer is used instead; that result is
 *     overwritten by the next md == NULL call, so this mode is neither
 *     reentrant nor thread-safe.
 *
 * Returns the buffer holding the digest. Errors reported by mbed TLS
 * trigger FATAL().
 */
unsigned char *
crypto_md5(const unsigned char *d, size_t n, unsigned char *md)
{
    static unsigned char m[16];
    if (md == NULL) {
        md = m;
    }
#if MBEDTLS_VERSION_NUMBER < 0x02070000
    // mbed TLS < 2.7: mbedtls_md5() returns void, nothing to check.
    mbedtls_md5(d, n, md);
#elif MBEDTLS_VERSION_NUMBER < 0x03000000
    // mbed TLS 2.7 - 2.x: the *_ret variant reports errors.
    if (mbedtls_md5_ret(d, n, md) != 0)
        FATAL("Failed to calculate MD5");
#else
    // mbed TLS 3.x removed mbedtls_md5_ret(); mbedtls_md5() returns int.
    if (mbedtls_md5(d, n, md) != 0)
        FATAL("Failed to calculate MD5");
#endif
    return md;
}
114 
/*
 * Linux-only advisory check: ask the kernel (RNDGETENTCNT on /dev/random)
 * for its entropy estimate and log a hint when it is below 160.
 * It never fails; if /dev/random cannot be opened or the ioctl fails,
 * nothing is logged. On other platforms this is a no-op.
 */
static void
entropy_check(void)
{
#if defined(__linux__) && defined(HAVE_LINUX_RANDOM_H) && defined(RNDGETENTCNT)
    int entropy = 0;
    int rnd     = open("/dev/random", O_RDONLY);

    if (rnd == -1) {
        return;
    }

    if (ioctl(rnd, RNDGETENTCNT, &entropy) == 0 && entropy < 160) {
        LOGI("This system doesn't provide enough entropy to quickly generate high-quality random numbers.\n"
             "Installing the rng-utils/rng-tools, jitterentropy or haveged packages may help.\n"
             "On virtualized Linux environments, also consider using virtio-rng.\n"
             "The service will not start until enough entropy has been collected.\n");
    }

    close(rnd);
#endif
}
133 
134 crypto_t *
crypto_init(const char * password,const char * key,const char * method)135 crypto_init(const char *password, const char *key, const char *method)
136 {
137     int i, m = -1;
138 
139     entropy_check();
140     // Initialize sodium for random generator
141     if (sodium_init() == -1) {
142         FATAL("Failed to initialize sodium");
143     }
144 
145     // Initialize NONCE bloom filter
146 #ifdef MODULE_REMOTE
147     ppbloom_init(BF_NUM_ENTRIES_FOR_SERVER, BF_ERROR_RATE_FOR_SERVER);
148 #else
149     ppbloom_init(BF_NUM_ENTRIES_FOR_CLIENT, BF_ERROR_RATE_FOR_CLIENT);
150 #endif
151 
152     if (method != NULL) {
153         for (i = 0; i < STREAM_CIPHER_NUM; i++)
154             if (strcmp(method, supported_stream_ciphers[i]) == 0) {
155                 m = i;
156                 break;
157             }
158         if (m != -1) {
159             LOGI("Stream ciphers are insecure, therefore deprecated, and should be almost always avoided.");
160             cipher_t *cipher = stream_init(password, key, method);
161             if (cipher == NULL)
162                 return NULL;
163             crypto_t *crypto = (crypto_t *)ss_malloc(sizeof(crypto_t));
164             crypto_t tmp     = {
165                 .cipher      = cipher,
166                 .encrypt_all = &stream_encrypt_all,
167                 .decrypt_all = &stream_decrypt_all,
168                 .encrypt     = &stream_encrypt,
169                 .decrypt     = &stream_decrypt,
170                 .ctx_init    = &stream_ctx_init,
171                 .ctx_release = &stream_ctx_release,
172             };
173             memcpy(crypto, &tmp, sizeof(crypto_t));
174             return crypto;
175         }
176 
177         for (i = 0; i < AEAD_CIPHER_NUM; i++)
178             if (strcmp(method, supported_aead_ciphers[i]) == 0) {
179                 m = i;
180                 break;
181             }
182         if (m != -1) {
183             cipher_t *cipher = aead_init(password, key, method);
184             if (cipher == NULL)
185                 return NULL;
186             crypto_t *crypto = (crypto_t *)ss_malloc(sizeof(crypto_t));
187             crypto_t tmp     = {
188                 .cipher      = cipher,
189                 .encrypt_all = &aead_encrypt_all,
190                 .decrypt_all = &aead_decrypt_all,
191                 .encrypt     = &aead_encrypt,
192                 .decrypt     = &aead_decrypt,
193                 .ctx_init    = &aead_ctx_init,
194                 .ctx_release = &aead_ctx_release,
195             };
196             memcpy(crypto, &tmp, sizeof(crypto_t));
197             return crypto;
198         }
199     }
200 
201     LOGE("invalid cipher name: %s", method);
202     return NULL;
203 }
204 
205 int
crypto_derive_key(const char * pass,uint8_t * key,size_t key_len)206 crypto_derive_key(const char *pass, uint8_t *key, size_t key_len)
207 {
208     size_t datal;
209     datal = strlen((const char *)pass);
210 
211     const digest_type_t *md = mbedtls_md_info_from_string("MD5");
212     if (md == NULL) {
213         FATAL("MD5 Digest not found in crypto library");
214     }
215 
216     mbedtls_md_context_t c;
217     unsigned char md_buf[MAX_MD_SIZE];
218     int addmd;
219     unsigned int i, j, mds;
220 
221     mds = mbedtls_md_get_size(md);
222     memset(&c, 0, sizeof(mbedtls_md_context_t));
223 
224     if (pass == NULL)
225         return key_len;
226     if (mbedtls_md_setup(&c, md, 0))
227         return 0;
228 
229     for (j = 0, addmd = 0; j < key_len; addmd++) {
230         mbedtls_md_starts(&c);
231         if (addmd) {
232             mbedtls_md_update(&c, md_buf, mds);
233         }
234         mbedtls_md_update(&c, (uint8_t *)pass, datal);
235         mbedtls_md_finish(&c, &(md_buf[0]));
236 
237         for (i = 0; i < mds; i++, j++) {
238             if (j >= key_len)
239                 break;
240             key[j] = md_buf[i];
241         }
242     }
243 
244     mbedtls_md_free(&c);
245     return key_len;
246 }
247 
248 /* HKDF-Extract + HKDF-Expand */
249 int
crypto_hkdf(const mbedtls_md_info_t * md,const unsigned char * salt,int salt_len,const unsigned char * ikm,int ikm_len,const unsigned char * info,int info_len,unsigned char * okm,int okm_len)250 crypto_hkdf(const mbedtls_md_info_t *md, const unsigned char *salt,
251             int salt_len, const unsigned char *ikm, int ikm_len,
252             const unsigned char *info, int info_len, unsigned char *okm,
253             int okm_len)
254 {
255     unsigned char prk[MBEDTLS_MD_MAX_SIZE];
256 
257     return crypto_hkdf_extract(md, salt, salt_len, ikm, ikm_len, prk) ||
258            crypto_hkdf_expand(md, prk, mbedtls_md_get_size(md), info, info_len,
259                               okm, okm_len);
260 }
261 
262 /* HKDF-Extract(salt, IKM) -> PRK */
263 int
crypto_hkdf_extract(const mbedtls_md_info_t * md,const unsigned char * salt,int salt_len,const unsigned char * ikm,int ikm_len,unsigned char * prk)264 crypto_hkdf_extract(const mbedtls_md_info_t *md, const unsigned char *salt,
265                     int salt_len, const unsigned char *ikm, int ikm_len,
266                     unsigned char *prk)
267 {
268     int hash_len;
269     unsigned char null_salt[MBEDTLS_MD_MAX_SIZE] = { '\0' };
270 
271     if (salt_len < 0) {
272         return CRYPTO_ERROR;
273     }
274 
275     hash_len = mbedtls_md_get_size(md);
276 
277     if (salt == NULL) {
278         salt     = null_salt;
279         salt_len = hash_len;
280     }
281 
282     return mbedtls_md_hmac(md, salt, salt_len, ikm, ikm_len, prk);
283 }
284 
285 /* HKDF-Expand(PRK, info, L) -> OKM */
286 int
crypto_hkdf_expand(const mbedtls_md_info_t * md,const unsigned char * prk,int prk_len,const unsigned char * info,int info_len,unsigned char * okm,int okm_len)287 crypto_hkdf_expand(const mbedtls_md_info_t *md, const unsigned char *prk,
288                    int prk_len, const unsigned char *info, int info_len,
289                    unsigned char *okm, int okm_len)
290 {
291     int hash_len;
292     int N;
293     int T_len = 0, where = 0, i, ret;
294     mbedtls_md_context_t ctx;
295     unsigned char T[MBEDTLS_MD_MAX_SIZE];
296 
297     if (info_len < 0 || okm_len < 0 || okm == NULL) {
298         return CRYPTO_ERROR;
299     }
300 
301     hash_len = mbedtls_md_get_size(md);
302 
303     if (prk_len < hash_len) {
304         return CRYPTO_ERROR;
305     }
306 
307     if (info == NULL) {
308         info = (const unsigned char *)"";
309     }
310 
311     N = okm_len / hash_len;
312 
313     if ((okm_len % hash_len) != 0) {
314         N++;
315     }
316 
317     if (N > 255) {
318         return CRYPTO_ERROR;
319     }
320 
321     mbedtls_md_init(&ctx);
322 
323     if ((ret = mbedtls_md_setup(&ctx, md, 1)) != 0) {
324         mbedtls_md_free(&ctx);
325         return ret;
326     }
327 
328     /* Section 2.3. */
329     for (i = 1; i <= N; i++) {
330         unsigned char c = i;
331 
332         ret = mbedtls_md_hmac_starts(&ctx, prk, prk_len) ||
333               mbedtls_md_hmac_update(&ctx, T, T_len) ||
334               mbedtls_md_hmac_update(&ctx, info, info_len) ||
335               /* The constant concatenated to the end of each T(n) is a single
336                * octet. */
337               mbedtls_md_hmac_update(&ctx, &c, 1) ||
338               mbedtls_md_hmac_finish(&ctx, T);
339 
340         if (ret != 0) {
341             mbedtls_md_free(&ctx);
342             return ret;
343         }
344 
345         memcpy(okm + where, T, (i != N) ? hash_len : (okm_len - where));
346         where += hash_len;
347         T_len  = hash_len;
348     }
349 
350     mbedtls_md_free(&ctx);
351 
352     return 0;
353 }
354 
355 int
crypto_parse_key(const char * base64,uint8_t * key,size_t key_len)356 crypto_parse_key(const char *base64, uint8_t *key, size_t key_len)
357 {
358     size_t base64_len = strlen(base64);
359     int out_len       = BASE64_SIZE(base64_len);
360     uint8_t out[out_len];
361 
362     out_len = base64_decode(out, base64, out_len);
363     if (out_len > 0 && out_len >= key_len) {
364         memcpy(key, out, key_len);
365 #ifdef SS_DEBUG
366         dump("KEY", (char *)key, key_len);
367 #endif
368         return key_len;
369     }
370 
371     out_len = BASE64_SIZE(key_len);
372     char out_key[out_len];
373     rand_bytes(key, key_len);
374     base64_encode(out_key, out_len, key, key_len);
375     LOGE("Invalid key for your chosen cipher!");
376     LOGE("It requires a " SIZE_FMT "-byte key encoded with URL-safe Base64", key_len);
377     LOGE("Generating a new random key: %s", out_key);
378     FATAL("Please use the key above or input a valid key");
379     return key_len;
380 }
381 
382 #ifdef SS_DEBUG
/*
 * Debug helper: print `tag`, then each of the first `len` bytes of `text`
 * as "0xNN " (unsigned hex), and finish with a newline on stdout.
 */
void
dump(char *tag, char *text, int len)
{
    const uint8_t *bytes = (const uint8_t *)text;

    printf("%s: ", tag);
    for (int k = 0; k < len; ++k) {
        printf("0x%02x ", bytes[k]);
    }
    printf("\n");
}
392 
393 #endif
394