1 #define PY_SSIZE_T_CLEAN
2 
3 #include <Python.h>
4 #include <openssl/err.h>
5 #include <openssl/evp.h>
6 
7 #define MODULE_NAME "aioquic._crypto"
8 
9 #define AEAD_KEY_LENGTH_MAX 32
10 #define AEAD_NONCE_LENGTH 12
11 #define AEAD_TAG_LENGTH 16
12 
13 #define PACKET_LENGTH_MAX 1500
14 #define PACKET_NUMBER_LENGTH_MAX 4
15 #define SAMPLE_LENGTH 16
16 
17 #define CHECK_RESULT(expr) \
18     if (!(expr)) { \
19         ERR_clear_error(); \
20         PyErr_SetString(CryptoError, "OpenSSL call failed"); \
21         return NULL; \
22     }
23 
24 #define CHECK_RESULT_CTOR(expr) \
25     if (!(expr)) { \
26         ERR_clear_error(); \
27         PyErr_SetString(CryptoError, "OpenSSL call failed"); \
28         return -1; \
29     }
30 
31 static PyObject *CryptoError;
32 
33 /* AEAD */
34 
/*
 * Python-level AEAD object: one OpenSSL cipher context per direction plus the
 * key and static IV they are re-initialised with for every packet.
 */
typedef struct {
    PyObject_HEAD
    EVP_CIPHER_CTX *decrypt_ctx;              /* created with operation 0 in AEAD_init */
    EVP_CIPHER_CTX *encrypt_ctx;              /* created with operation 1 in AEAD_init */
    unsigned char buffer[PACKET_LENGTH_MAX];  /* scratch output for encrypt/decrypt */
    unsigned char key[AEAD_KEY_LENGTH_MAX];   /* key passed to __init__ (up to AEAD_KEY_LENGTH_MAX bytes) */
    unsigned char iv[AEAD_NONCE_LENGTH];      /* static IV; the packet number is XORed into a copy of it */
    unsigned char nonce[AEAD_NONCE_LENGTH];   /* per-call nonce built from iv and the packet number */
} AEADObject;
44 
45 static EVP_CIPHER_CTX *
create_ctx(const EVP_CIPHER * cipher,int key_length,int operation)46 create_ctx(const EVP_CIPHER *cipher, int key_length, int operation)
47 {
48     EVP_CIPHER_CTX *ctx;
49     int res;
50 
51     ctx = EVP_CIPHER_CTX_new();
52     CHECK_RESULT(ctx != 0);
53 
54     res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation);
55     CHECK_RESULT(res != 0);
56 
57     res = EVP_CIPHER_CTX_set_key_length(ctx, key_length);
58     CHECK_RESULT(res != 0);
59 
60     res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL);
61     CHECK_RESULT(res != 0);
62 
63     return ctx;
64 }
65 
66 static int
AEAD_init(AEADObject * self,PyObject * args,PyObject * kwargs)67 AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs)
68 {
69     const char *cipher_name;
70     const unsigned char *key, *iv;
71     Py_ssize_t cipher_name_len, key_len, iv_len;
72 
73     if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, &cipher_name_len, &key, &key_len, &iv, &iv_len))
74         return -1;
75 
76     const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
77     if (evp_cipher == 0) {
78         PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
79         return -1;
80     }
81     if (key_len > AEAD_KEY_LENGTH_MAX) {
82         PyErr_SetString(CryptoError, "Invalid key length");
83         return -1;
84     }
85     if (iv_len > AEAD_NONCE_LENGTH) {
86         PyErr_SetString(CryptoError, "Invalid iv length");
87         return -1;
88     }
89 
90     memcpy(self->key, key, key_len);
91     memcpy(self->iv, iv, iv_len);
92 
93     self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0);
94     CHECK_RESULT_CTOR(self->decrypt_ctx != 0);
95 
96     self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1);
97     CHECK_RESULT_CTOR(self->encrypt_ctx != 0);
98 
99     return 0;
100 }
101 
102 static void
AEAD_dealloc(AEADObject * self)103 AEAD_dealloc(AEADObject *self)
104 {
105     EVP_CIPHER_CTX_free(self->decrypt_ctx);
106     EVP_CIPHER_CTX_free(self->encrypt_ctx);
107 }
108 
109 static PyObject*
AEAD_decrypt(AEADObject * self,PyObject * args)110 AEAD_decrypt(AEADObject *self, PyObject *args)
111 {
112     const unsigned char *data, *associated;
113     Py_ssize_t data_len, associated_len;
114     int outlen, outlen2, res;
115     uint64_t pn;
116 
117     if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
118         return NULL;
119 
120     if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) {
121         PyErr_SetString(CryptoError, "Invalid payload length");
122         return NULL;
123     }
124 
125     memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
126     for (int i = 0; i < 8; ++i) {
127         self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
128     }
129 
130     res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH)));
131     CHECK_RESULT(res != 0);
132 
133     res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0);
134     CHECK_RESULT(res != 0);
135 
136     res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len);
137     CHECK_RESULT(res != 0);
138 
139     res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH);
140     CHECK_RESULT(res != 0);
141 
142     res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2);
143     if (res == 0) {
144         PyErr_SetString(CryptoError, "Payload decryption failed");
145         return NULL;
146     }
147 
148     return PyBytes_FromStringAndSize((const char*)self->buffer, outlen);
149 }
150 
151 static PyObject*
AEAD_encrypt(AEADObject * self,PyObject * args)152 AEAD_encrypt(AEADObject *self, PyObject *args)
153 {
154     const unsigned char *data, *associated;
155     Py_ssize_t data_len, associated_len;
156     int outlen, outlen2, res;
157     uint64_t pn;
158 
159     if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
160         return NULL;
161 
162     if (data_len > PACKET_LENGTH_MAX) {
163         PyErr_SetString(CryptoError, "Invalid payload length");
164         return NULL;
165     }
166 
167     memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
168     for (int i = 0; i < 8; ++i) {
169         self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
170     }
171 
172     res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1);
173     CHECK_RESULT(res != 0);
174 
175     res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len);
176     CHECK_RESULT(res != 0);
177 
178     res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len);
179     CHECK_RESULT(res != 0);
180 
181     res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2);
182     CHECK_RESULT(res != 0 && outlen2 == 0);
183 
184     res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen);
185     CHECK_RESULT(res != 0);
186 
187     return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH);
188 }
189 
190 static PyMethodDef AEAD_methods[] = {
191     {"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""},
192     {"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""},
193 
194     {NULL}
195 };
196 
197 static PyTypeObject AEADType = {
198     PyVarObject_HEAD_INIT(NULL, 0)
199     MODULE_NAME ".AEAD",                /* tp_name */
200     sizeof(AEADObject),                 /* tp_basicsize */
201     0,                                  /* tp_itemsize */
202     (destructor)AEAD_dealloc,           /* tp_dealloc */
203     0,                                  /* tp_print */
204     0,                                  /* tp_getattr */
205     0,                                  /* tp_setattr */
206     0,                                  /* tp_reserved */
207     0,                                  /* tp_repr */
208     0,                                  /* tp_as_number */
209     0,                                  /* tp_as_sequence */
210     0,                                  /* tp_as_mapping */
211     0,                                  /* tp_hash  */
212     0,                                  /* tp_call */
213     0,                                  /* tp_str */
214     0,                                  /* tp_getattro */
215     0,                                  /* tp_setattro */
216     0,                                  /* tp_as_buffer */
217     Py_TPFLAGS_DEFAULT,                 /* tp_flags */
218     "AEAD objects",                     /* tp_doc */
219     0,                                  /* tp_traverse */
220     0,                                  /* tp_clear */
221     0,                                  /* tp_richcompare */
222     0,                                  /* tp_weaklistoffset */
223     0,                                  /* tp_iter */
224     0,                                  /* tp_iternext */
225     AEAD_methods,                       /* tp_methods */
226     0,                                  /* tp_members */
227     0,                                  /* tp_getset */
228     0,                                  /* tp_base */
229     0,                                  /* tp_dict */
230     0,                                  /* tp_descr_get */
231     0,                                  /* tp_descr_set */
232     0,                                  /* tp_dictoffset */
233     (initproc)AEAD_init,                /* tp_init */
234     0,                                  /* tp_alloc */
235 };
236 
237 /* HeaderProtection */
238 
/*
 * Python-level header-protection object: an OpenSSL encryption context that
 * is keyed once and then used to derive a per-packet mask.
 */
typedef struct {
    PyObject_HEAD
    EVP_CIPHER_CTX *ctx;                      /* keyed in HeaderProtection_init; may be NULL before init */
    int is_chacha20;                          /* non-zero when the cipher name is exactly "chacha20" */
    unsigned char buffer[PACKET_LENGTH_MAX];  /* scratch output for apply/remove */
    /* Only mask[0..4] are read. NOTE(review): the size 31 presumably covers
     * EVP_CipherUpdate's worst-case output for a 16-byte input with a 16-byte
     * block (16 + 16 - 1) — confirm. */
    unsigned char mask[31];
    unsigned char zero[5];                    /* zero bytes encrypted to form the mask in the chacha20 case */
} HeaderProtectionObject;
247 
248 static int
HeaderProtection_init(HeaderProtectionObject * self,PyObject * args,PyObject * kwargs)249 HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs)
250 {
251     const char *cipher_name;
252     const unsigned char *key;
253     Py_ssize_t cipher_name_len, key_len;
254     int res;
255 
256     if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len))
257         return -1;
258 
259     const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
260     if (evp_cipher == 0) {
261         PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
262         return -1;
263     }
264 
265     memset(self->mask, 0, sizeof(self->mask));
266     memset(self->zero, 0, sizeof(self->zero));
267     self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0;
268 
269     self->ctx = EVP_CIPHER_CTX_new();
270     CHECK_RESULT_CTOR(self->ctx != 0);
271 
272     res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1);
273     CHECK_RESULT_CTOR(res != 0);
274 
275     res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len);
276     CHECK_RESULT_CTOR(res != 0);
277 
278     res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1);
279     CHECK_RESULT_CTOR(res != 0);
280 
281     return 0;
282 }
283 
284 static void
HeaderProtection_dealloc(HeaderProtectionObject * self)285 HeaderProtection_dealloc(HeaderProtectionObject *self)
286 {
287     EVP_CIPHER_CTX_free(self->ctx);
288 }
289 
/*
 * Derive the header-protection mask for `sample` into self->mask.
 *
 * Non-chacha20 ciphers encrypt SAMPLE_LENGTH bytes of the sample directly.
 * For chacha20, the sample is installed as the IV and the five zero bytes
 * in self->zero are encrypted.
 *
 * Returns non-zero on success and 0 when an OpenSSL call fails.
 */
static int
HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char *sample)
{
    int produced;

    if (!self->is_chacha20)
        return EVP_CipherUpdate(self->ctx, self->mask, &produced, sample, SAMPLE_LENGTH);

    if (!EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1))
        return 0;
    return EVP_CipherUpdate(self->ctx, self->mask, &produced, self->zero, sizeof(self->zero)) != 0;
}
300 
301 static PyObject*
HeaderProtection_apply(HeaderProtectionObject * self,PyObject * args)302 HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args)
303 {
304     const unsigned char *header, *payload;
305     Py_ssize_t header_len, payload_len;
306     int res;
307 
308     if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len))
309         return NULL;
310 
311     int pn_length = (header[0] & 0x03) + 1;
312     int pn_offset = header_len - pn_length;
313 
314     res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length);
315     CHECK_RESULT(res != 0);
316 
317     memcpy(self->buffer, header, header_len);
318     memcpy(self->buffer + header_len, payload, payload_len);
319 
320     if (self->buffer[0] & 0x80) {
321         self->buffer[0] ^= self->mask[0] & 0x0F;
322     } else {
323         self->buffer[0] ^= self->mask[0] & 0x1F;
324     }
325 
326     for (int i = 0; i < pn_length; ++i) {
327         self->buffer[pn_offset + i] ^= self->mask[1 + i];
328     }
329 
330     return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len);
331 }
332 
333 static PyObject*
HeaderProtection_remove(HeaderProtectionObject * self,PyObject * args)334 HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args)
335 {
336     const unsigned char *packet;
337     Py_ssize_t packet_len;
338     int pn_offset, res;
339 
340     if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset))
341         return NULL;
342 
343     res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX);
344     CHECK_RESULT(res != 0);
345 
346     memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX);
347 
348     if (self->buffer[0] & 0x80) {
349         self->buffer[0] ^= self->mask[0] & 0x0F;
350     } else {
351         self->buffer[0] ^= self->mask[0] & 0x1F;
352     }
353 
354     int pn_length = (self->buffer[0] & 0x03) + 1;
355     uint32_t pn_truncated = 0;
356     for (int i = 0; i < pn_length; ++i) {
357         self->buffer[pn_offset + i] ^= self->mask[1 + i];
358         pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8);
359     }
360 
361     return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated);
362 }
363 
364 static PyMethodDef HeaderProtection_methods[] = {
365     {"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""},
366     {"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""},
367     {NULL}
368 };
369 
370 static PyTypeObject HeaderProtectionType = {
371     PyVarObject_HEAD_INIT(NULL, 0)
372     MODULE_NAME ".HeaderProtection",    /* tp_name */
373     sizeof(HeaderProtectionObject),     /* tp_basicsize */
374     0,                                  /* tp_itemsize */
375     (destructor)HeaderProtection_dealloc,   /* tp_dealloc */
376     0,                                  /* tp_print */
377     0,                                  /* tp_getattr */
378     0,                                  /* tp_setattr */
379     0,                                  /* tp_reserved */
380     0,                                  /* tp_repr */
381     0,                                  /* tp_as_number */
382     0,                                  /* tp_as_sequence */
383     0,                                  /* tp_as_mapping */
384     0,                                  /* tp_hash  */
385     0,                                  /* tp_call */
386     0,                                  /* tp_str */
387     0,                                  /* tp_getattro */
388     0,                                  /* tp_setattro */
389     0,                                  /* tp_as_buffer */
390     Py_TPFLAGS_DEFAULT,                 /* tp_flags */
391     "HeaderProtection objects",         /* tp_doc */
392     0,                                  /* tp_traverse */
393     0,                                  /* tp_clear */
394     0,                                  /* tp_richcompare */
395     0,                                  /* tp_weaklistoffset */
396     0,                                  /* tp_iter */
397     0,                                  /* tp_iternext */
398     HeaderProtection_methods,           /* tp_methods */
399     0,                                  /* tp_members */
400     0,                                  /* tp_getset */
401     0,                                  /* tp_base */
402     0,                                  /* tp_dict */
403     0,                                  /* tp_descr_get */
404     0,                                  /* tp_descr_set */
405     0,                                  /* tp_dictoffset */
406     (initproc)HeaderProtection_init,    /* tp_init */
407     0,                                  /* tp_alloc */
408 };
409 
410 
411 static struct PyModuleDef moduledef = {
412     PyModuleDef_HEAD_INIT,
413     MODULE_NAME,                        /* m_name */
414     "A faster buffer.",                 /* m_doc */
415     -1,                                 /* m_size */
416     NULL,                               /* m_methods */
417     NULL,                               /* m_reload */
418     NULL,                               /* m_traverse */
419     NULL,                               /* m_clear */
420     NULL,                               /* m_free */
421 };
422 
423 PyMODINIT_FUNC
PyInit__crypto(void)424 PyInit__crypto(void)
425 {
426     PyObject* m;
427 
428     m = PyModule_Create(&moduledef);
429     if (m == NULL)
430         return NULL;
431 
432     CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL);
433     Py_INCREF(CryptoError);
434     PyModule_AddObject(m, "CryptoError", CryptoError);
435 
436     AEADType.tp_new = PyType_GenericNew;
437     if (PyType_Ready(&AEADType) < 0)
438         return NULL;
439     Py_INCREF(&AEADType);
440     PyModule_AddObject(m, "AEAD", (PyObject *)&AEADType);
441 
442     HeaderProtectionType.tp_new = PyType_GenericNew;
443     if (PyType_Ready(&HeaderProtectionType) < 0)
444         return NULL;
445     Py_INCREF(&HeaderProtectionType);
446     PyModule_AddObject(m, "HeaderProtection", (PyObject *)&HeaderProtectionType);
447 
448     // ensure required ciphers are initialised
449     EVP_add_cipher(EVP_aes_128_ecb());
450     EVP_add_cipher(EVP_aes_128_gcm());
451     EVP_add_cipher(EVP_aes_256_ecb());
452     EVP_add_cipher(EVP_aes_256_gcm());
453 
454     return m;
455 }
456