1 /**
2  * @file
3  * Application layered TCP/TLS connection API (to be used from TCPIP thread)
4  *
5  * This file provides a TLS layer using mbedTLS
6  *
7  * This version is currently compatible with the 2.x.x branch (current LTS).
8  */
9 
10 /*
11  * Copyright (c) 2017 Simon Goldschmidt
12  * All rights reserved.
13  *
14  * Redistribution and use in source and binary forms, with or without modification,
15  * are permitted provided that the following conditions are met:
16  *
17  * 1. Redistributions of source code must retain the above copyright notice,
18  *    this list of conditions and the following disclaimer.
19  * 2. Redistributions in binary form must reproduce the above copyright notice,
20  *    this list of conditions and the following disclaimer in the documentation
21  *    and/or other materials provided with the distribution.
22  * 3. The name of the author may not be used to endorse or promote products
23  *    derived from this software without specific prior written permission.
24  *
25  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
26  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
28  * SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
29  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
30  * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
31  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
32  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
33  * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
34  * OF SUCH DAMAGE.
35  *
36  * This file is part of the lwIP TCP/IP stack.
37  *
38  * Author: Simon Goldschmidt <goldsimon@gmx.de>
39  *
40  * Watch out:
 * - 'sent' is reported to the upper layer with the number of ACKed application
 *   bytes only: the TLS record overhead contained in each ACK is accounted for
 *   and subtracted before the upper 'sent' callback is invoked.
43  *
44  * Mandatory security-related configuration:
45  * - ensure to add at least one strong entropy source to your mbedtls port (implement
46  *   mbedtls_platform_entropy_poll or mbedtls_hardware_poll providing strong entropy)
47  * - define ALTCP_MBEDTLS_ENTROPY_PTR and ALTCP_MBEDTLS_ENTROPY_LEN to something providing
48  *   GOOD custom entropy
49  *
50  * Missing things / @todo:
 * - some unhandled/untested things might be caught by LWIP_ASSERTs...
52  */
53 
54 #include "lwip/opt.h"
55 #include "lwip/sys.h"
56 
57 #if LWIP_ALTCP /* don't build if not configured for use in lwipopts.h */
58 
59 #include "lwip/apps/altcp_tls_mbedtls_opts.h"
60 
61 #if LWIP_ALTCP_TLS && LWIP_ALTCP_TLS_MBEDTLS
62 
63 #include "lwip/altcp.h"
64 #include "lwip/altcp_tls.h"
65 #include "lwip/priv/altcp_priv.h"
66 
67 #include "altcp_tls_mbedtls_structs.h"
68 #include "altcp_tls_mbedtls_mem.h"
69 
70 /* @todo: which includes are really needed? */
71 #include "mbedtls/entropy.h"
72 #include "mbedtls/ctr_drbg.h"
73 #include "mbedtls/certs.h"
74 #include "mbedtls/x509.h"
75 #include "mbedtls/ssl.h"
76 #include "mbedtls/net_sockets.h"
77 #include "mbedtls/error.h"
78 #include "mbedtls/debug.h"
79 #include "mbedtls/platform.h"
80 #include "mbedtls/memory_buffer_alloc.h"
81 #include "mbedtls/ssl_cache.h"
82 #include "mbedtls/ssl_ticket.h"
83 
84 #include "mbedtls/ssl_internal.h" /* to call mbedtls_flush_output after ERR_MEM */
85 
86 #include <string.h>
87 
/* Custom data passed as the 'custom'/'len' arguments of mbedtls_ctr_drbg_seed()
   when the shared RNG is seeded (see the "Mandatory security-related
   configuration" notes above). Defaults to no custom data. */
#ifndef ALTCP_MBEDTLS_ENTROPY_PTR
#define ALTCP_MBEDTLS_ENTROPY_PTR   NULL
#endif
#ifndef ALTCP_MBEDTLS_ENTROPY_LEN
#define ALTCP_MBEDTLS_ENTROPY_LEN   0
#endif
/* Entropy callback handed to mbedtls_ctr_drbg_seed(); defaults to mbedTLS'
   entropy collector working on the shared mbedtls_entropy_context. */
#ifndef ALTCP_MBEDTLS_RNG_FN
#define ALTCP_MBEDTLS_RNG_FN   mbedtls_entropy_func
#endif
97 
/* Variable prototype: the actual definition is at the end of this file
   since it contains pointers to static functions declared here.
   It is installed as conn->fns by altcp_mbedtls_setup(). */
extern const struct altcp_functions altcp_mbedtls_functions;
101 
/** Our global mbedTLS configuration: shared by all connections created from it
 * (not connection-specific) and used for both client and server setups.
 * The cert, ca and pkey storage lives in the same allocation as this struct,
 * directly behind it, in the order: cert[cert_max], ca (one entry, if any),
 * pkey[pkey_max] (laid out by altcp_tls_create_config()).
 */
struct altcp_tls_config {
  /** mbedTLS SSL configuration handed to mbedtls_ssl_setup() for each connection */
  mbedtls_ssl_config conf;
  /** array of cert_max certificates (only set if cert_max > 0) */
  mbedtls_x509_crt *cert;
  /** array of pkey_max private keys (only set if pkey_max > 0) */
  mbedtls_pk_context *pkey;
  /** certificates in use - NOTE(review): presumably maintained by code outside this view */
  u8_t cert_count;
  /** capacity of 'cert' */
  u8_t cert_max;
  /** private keys in use - NOTE(review): presumably maintained by code outside this view */
  u8_t pkey_count;
  /** capacity of 'pkey' */
  u8_t pkey_max;
  /** single CA certificate slot (only set if the config was created with a CA) */
  mbedtls_x509_crt *ca;
#if defined(MBEDTLS_SSL_CACHE_C) && ALTCP_MBEDTLS_USE_SESSION_CACHE
  /** Inter-connection cache for fast connection startup */
  struct mbedtls_ssl_cache_context cache;
#endif
#if defined(MBEDTLS_SSL_SESSION_TICKETS) && ALTCP_MBEDTLS_USE_SESSION_TICKETS
  /** session ticket context (set up with the shared CTR_DRBG) */
  mbedtls_ssl_ticket_context ticket_ctx;
#endif
};
120 
/** Entropy and random generator are shared by all mbedTLS configurations.
 * The instance is created and seeded on first use and reference counted by
 * altcp_mbedtls_ref_entropy() / altcp_mbedtls_unref_entropy().
 */
struct altcp_tls_entropy_rng {
  /** entropy context used (via ALTCP_MBEDTLS_RNG_FN) to seed ctr_drbg */
  mbedtls_entropy_context entropy;
  /** random generator registered with mbedtls_ssl_conf_rng() for every config */
  mbedtls_ctr_drbg_context ctr_drbg;
  /** number of references taken through altcp_mbedtls_ref_entropy() */
  int ref;
};
/** The single shared instance; NULL until the first reference is taken */
static struct altcp_tls_entropy_rng *altcp_tls_entropy_rng;
128 
/* Forward declarations of static functions; several of them are used before
   their definition in this file. */
static err_t altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p, err_t err);
static err_t altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn);
static err_t altcp_mbedtls_lower_recv_process(struct altcp_pcb *conn, altcp_mbedtls_state_t *state);
static err_t altcp_mbedtls_handle_rx_appldata(struct altcp_pcb *conn, altcp_mbedtls_state_t *state);
static int altcp_mbedtls_bio_send(void *ctx, const unsigned char *dataptr, size_t size);
134 
135 
136 /* callback functions from inner/lower connection: */
137 
138 /** Accept callback from lower connection (i.e. TCP)
139  * Allocate one of our structures, assign it to the new connection's 'state' and
140  * call the new connection's 'accepted' callback. If that succeeds, we wait
141  * to receive connection setup handshake bytes from the client.
142  */
143 static err_t
altcp_mbedtls_lower_accept(void * arg,struct altcp_pcb * accepted_conn,err_t err)144 altcp_mbedtls_lower_accept(void *arg, struct altcp_pcb *accepted_conn, err_t err)
145 {
146   struct altcp_pcb *listen_conn = (struct altcp_pcb *)arg;
147   if (listen_conn && listen_conn->state && listen_conn->accept) {
148     err_t setup_err;
149     altcp_mbedtls_state_t *listen_state = (altcp_mbedtls_state_t *)listen_conn->state;
150     /* create a new altcp_conn to pass to the next 'accept' callback */
151     struct altcp_pcb *new_conn = altcp_alloc();
152     if (new_conn == NULL) {
153       return ERR_MEM;
154     }
155     setup_err = altcp_mbedtls_setup(listen_state->conf, new_conn, accepted_conn);
156     if (setup_err != ERR_OK) {
157       altcp_free(new_conn);
158       return setup_err;
159     }
160     return listen_conn->accept(listen_conn->arg, new_conn, err);
161   }
162   return ERR_ARG;
163 }
164 
165 /** Connected callback from lower connection (i.e. TCP).
166  * Not really implemented/tested yet...
167  */
168 static err_t
altcp_mbedtls_lower_connected(void * arg,struct altcp_pcb * inner_conn,err_t err)169 altcp_mbedtls_lower_connected(void *arg, struct altcp_pcb *inner_conn, err_t err)
170 {
171   struct altcp_pcb *conn = (struct altcp_pcb *)arg;
172   LWIP_UNUSED_ARG(inner_conn); /* for LWIP_NOASSERT */
173   if (conn && conn->state) {
174     altcp_mbedtls_state_t *state;
175     LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
176     /* upper connected is called when handshake is done */
177     if (err != ERR_OK) {
178       if (conn->connected) {
179         return conn->connected(conn->arg, conn, err);
180       }
181     }
182     state = (altcp_mbedtls_state_t *)conn->state;
183     /* ensure overhead value is valid before first write */
184     state->overhead_bytes_adjust = 0;
185     return altcp_mbedtls_lower_recv_process(conn, state);
186   }
187   return ERR_VAL;
188 }
189 
190 /* Call recved for possibly more than an u16_t */
191 static void
altcp_mbedtls_lower_recved(struct altcp_pcb * inner_conn,int recvd_cnt)192 altcp_mbedtls_lower_recved(struct altcp_pcb *inner_conn, int recvd_cnt)
193 {
194   while (recvd_cnt > 0) {
195     u16_t recvd_part = (u16_t)LWIP_MIN(recvd_cnt, 0xFFFF);
196     altcp_recved(inner_conn, recvd_part);
197     recvd_cnt -= recvd_part;
198   }
199 }
200 
201 /** Recv callback from lower connection (i.e. TCP)
202  * This one mainly differs between connection setup/handshake (data is fed into mbedTLS only)
203  * and application phase (data is decoded by mbedTLS and passed on to the application).
204  */
205 static err_t
altcp_mbedtls_lower_recv(void * arg,struct altcp_pcb * inner_conn,struct pbuf * p,err_t err)206 altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p, err_t err)
207 {
208   altcp_mbedtls_state_t *state;
209   struct altcp_pcb *conn = (struct altcp_pcb *)arg;
210 
211   LWIP_ASSERT("no err expected", err == ERR_OK);
212   LWIP_UNUSED_ARG(err);
213 
214   if (!conn) {
215     /* no connection given as arg? should not happen, but prevent pbuf/conn leaks */
216     if (p != NULL) {
217       pbuf_free(p);
218     }
219     altcp_close(inner_conn);
220     return ERR_CLSD;
221   }
222   state = (altcp_mbedtls_state_t *)conn->state;
223   LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
224   if (!state) {
225     /* already closed */
226     if (p != NULL) {
227       pbuf_free(p);
228     }
229     altcp_close(inner_conn);
230     return ERR_CLSD;
231   }
232 
233   /* handle NULL pbuf (inner connection closed) */
234   if (p == NULL) {
235     /* remote host sent FIN, remember this (SSL state is destroyed
236         when both sides are closed only!) */
237     if ((state->flags & (ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE | ALTCP_MBEDTLS_FLAGS_UPPER_CALLED)) ==
238         (ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE | ALTCP_MBEDTLS_FLAGS_UPPER_CALLED)) {
239       /* need to notify upper layer (e.g. 'accept' called or 'connect' succeeded) */
240       if ((state->rx != NULL) || (state->rx_app != NULL)) {
241         state->flags |= ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED;
242         /* this is a normal close (FIN) but we have unprocessed data, so delay the FIN */
243         altcp_mbedtls_handle_rx_appldata(conn, state);
244         return ERR_OK;
245       }
246       state->flags |= ALTCP_MBEDTLS_FLAGS_RX_CLOSED;
247       if (conn->recv) {
248         return conn->recv(conn->arg, conn, NULL, ERR_OK);
249       }
250     } else {
251       /* before connection setup is done: call 'err' */
252       if (conn->err) {
253         conn->err(conn->arg, ERR_ABRT);
254       }
255       altcp_close(conn);
256     }
257     return ERR_OK;
258   }
259 
260   /* If we come here, the connection is in good state (handshake phase or application data phase).
261      Queue up the pbuf for processing as handshake data or application data. */
262   if (state->rx == NULL) {
263     state->rx = p;
264   } else {
265     LWIP_ASSERT("rx pbuf overflow", (int)p->tot_len + (int)p->len <= 0xFFFF);
266     pbuf_cat(state->rx, p);
267   }
268   return altcp_mbedtls_lower_recv_process(conn, state);
269 }
270 
271 static err_t
altcp_mbedtls_lower_recv_process(struct altcp_pcb * conn,altcp_mbedtls_state_t * state)272 altcp_mbedtls_lower_recv_process(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
273 {
274   if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
275     /* handle connection setup (handshake not done) */
276     int ret = mbedtls_ssl_handshake(&state->ssl_context);
277     /* try to send data... */
278     altcp_output(conn->inner_conn);
279     if (state->bio_bytes_read) {
280       /* acknowledge all bytes read */
281       altcp_mbedtls_lower_recved(conn->inner_conn, state->bio_bytes_read);
282       state->bio_bytes_read = 0;
283     }
284 
285     if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
286       /* handshake not done, wait for more recv calls */
287       LWIP_ASSERT("in this state, the rx chain should be empty", state->rx == NULL);
288       return ERR_OK;
289     }
290     if (ret != 0) {
291       LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_handshake failed: %d\n", ret));
292       /* handshake failed, connection has to be closed */
293       if (conn->err) {
294         conn->err(conn->arg, ERR_CLSD);
295       }
296 
297       if (altcp_close(conn) != ERR_OK) {
298         altcp_abort(conn);
299       }
300       return ERR_OK;
301     }
302     /* If we come here, handshake succeeded. */
303     LWIP_ASSERT("state", state->bio_bytes_read == 0);
304     LWIP_ASSERT("state", state->bio_bytes_appl == 0);
305     state->flags |= ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE;
306     /* issue "connect" callback" to upper connection (this can only happen for active open) */
307     if (conn->connected) {
308       err_t err;
309       err = conn->connected(conn->arg, conn, ERR_OK);
310       if (err != ERR_OK) {
311         return err;
312       }
313     }
314     if (state->rx == NULL) {
315       return ERR_OK;
316     }
317   }
318   /* handle application data */
319   return altcp_mbedtls_handle_rx_appldata(conn, state);
320 }
321 
322 /* Pass queued decoded rx data to application */
323 static err_t
altcp_mbedtls_pass_rx_data(struct altcp_pcb * conn,altcp_mbedtls_state_t * state)324 altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
325 {
326   err_t err;
327   struct pbuf *buf;
328   LWIP_ASSERT("conn != NULL", conn != NULL);
329   LWIP_ASSERT("state != NULL", state != NULL);
330   buf = state->rx_app;
331   if (buf) {
332     state->rx_app = NULL;
333     if (conn->recv) {
334       u16_t tot_len = buf->tot_len;
335       /* this needs to be increased first because the 'recved' call may come nested */
336       state->rx_passed_unrecved += tot_len;
337       state->flags |= ALTCP_MBEDTLS_FLAGS_UPPER_CALLED;
338       err = conn->recv(conn->arg, conn, buf, ERR_OK);
339       if (err != ERR_OK) {
340         if (err == ERR_ABRT) {
341           return ERR_ABRT;
342         }
343         /* not received, leave the pbuf(s) queued (and decrease 'unrecved' again) */
344         LWIP_ASSERT("state == conn->state", state == conn->state);
345         state->rx_app = buf;
346         state->rx_passed_unrecved -= tot_len;
347         LWIP_ASSERT("state->rx_passed_unrecved >= 0", state->rx_passed_unrecved >= 0);
348         if (state->rx_passed_unrecved < 0) {
349           state->rx_passed_unrecved = 0;
350         }
351         return err;
352       }
353     } else {
354       pbuf_free(buf);
355     }
356   } else if ((state->flags & (ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED | ALTCP_MBEDTLS_FLAGS_RX_CLOSED)) ==
357              ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED) {
358     state->flags |= ALTCP_MBEDTLS_FLAGS_RX_CLOSED;
359     if (conn->recv) {
360       return conn->recv(conn->arg, conn, NULL, ERR_OK);
361     }
362   }
363 
364   /* application may have close the connection */
365   if (conn->state != state) {
366     /* return error code to ensure altcp_mbedtls_handle_rx_appldata() exits the loop */
367     return ERR_ARG;
368   }
369   return ERR_OK;
370 }
371 
/* Helper function that processes rx application data stored in rx pbuf chain.
 * Repeatedly decrypts data from state->rx (mbedtls_ssl_read() pulls it through
 * altcp_mbedtls_bio_recv()) into freshly allocated PBUF_POOL pbufs, appends them
 * to state->rx_app and hands them to the application via
 * altcp_mbedtls_pass_rx_data(). The loop ends when mbedtls_ssl_read() returns
 * <= 0, when no pbuf can be allocated or when the application does not accept
 * the data.
 *
 * Returns ERR_VAL if the handshake is not done yet, ERR_ABRT if the connection
 * has been aborted (here or by the application), ERR_OK otherwise (including
 * pbuf shortage and application errors, which are retried later from 'poll' or
 * 'recv').
 */
static err_t
altcp_mbedtls_handle_rx_appldata(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
{
  int ret;
  LWIP_ASSERT("state != NULL", state != NULL);
  if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
    /* handshake not done yet */
    return ERR_VAL;
  }
  do {
    /* allocate a full-sized unchained PBUF_POOL: this is for RX! */
    struct pbuf *buf = pbuf_alloc(PBUF_RAW, PBUF_POOL_BUFSIZE, PBUF_POOL);
    if (buf == NULL) {
      /* We're short on pbufs, try again later from 'poll' or 'recv' callbacks.
         @todo: close on excessive allocation failures or leave this up to upper conn? */
      return ERR_OK;
    }

    /* decrypt application data, this pulls encrypted RX data off state->rx pbuf chain */
    ret = mbedtls_ssl_read(&state->ssl_context, (unsigned char *)buf->payload, PBUF_POOL_BUFSIZE);
    if (ret < 0) {
      if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) {
        /* client is initiating a new connection using the same source port -> close connection or make handshake */
        LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("new connection on same source port\n"));
        /* with LWIP_NOASSERT, execution falls through to the abort below */
        LWIP_ASSERT("TODO: new connection on same source port, close this connection", 0);
      } else if ((ret != MBEDTLS_ERR_SSL_WANT_READ) && (ret != MBEDTLS_ERR_SSL_WANT_WRITE)) {
        /* NOTE(review): fatal read errors (incl. close_notify) are only logged;
           the upper layer is not notified here - confirm this is intended */
        if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
          LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("connection was closed gracefully\n"));
        } else if (ret == MBEDTLS_ERR_NET_CONN_RESET) {
          LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("connection was reset by peer\n"));
        }
        pbuf_free(buf);
        return ERR_OK;
      } else {
        /* WANT_READ/WANT_WRITE: no complete record available yet */
        pbuf_free(buf);
        return ERR_OK;
      }
      pbuf_free(buf);
      altcp_abort(conn);
      return ERR_ABRT;
    } else {
      err_t err;
      if (ret) {
        LWIP_ASSERT("bogus receive length", ret <= PBUF_POOL_BUFSIZE);
        /* trim pool pbuf to actually decoded length */
        pbuf_realloc(buf, (u16_t)ret);

        state->bio_bytes_appl += ret;
        if (mbedtls_ssl_get_bytes_avail(&state->ssl_context) == 0) {
          /* Record is done, now we know the share between application and protocol bytes
             and can adjust the RX window by the protocol bytes.
             The rest is 'recved' by the application calling our 'recved' fn.
             (bio_bytes_read counts the encrypted bytes pulled via altcp_mbedtls_bio_recv) */
          int overhead_bytes;
          LWIP_ASSERT("bogus byte counts", state->bio_bytes_read > state->bio_bytes_appl);
          overhead_bytes = state->bio_bytes_read - state->bio_bytes_appl;
          altcp_mbedtls_lower_recved(conn->inner_conn, overhead_bytes);
          state->bio_bytes_read = 0;
          state->bio_bytes_appl = 0;
        }

        if (state->rx_app == NULL) {
          state->rx_app = buf;
        } else {
          pbuf_cat(state->rx_app, buf);
        }
      } else {
        /* ret == 0: no data; pass_rx_data() below may still deliver a queued close */
        pbuf_free(buf);
        buf = NULL;
      }
      err = altcp_mbedtls_pass_rx_data(conn, state);
      if (err != ERR_OK) {
        if (err == ERR_ABRT) {
          /* recv callback needs to return this as the pcb is deallocated */
          return ERR_ABRT;
        }
        /* we hide all other errors as we retry feeding the pbuf to the app later */
        return ERR_OK;
      }
    }
  } while (ret > 0);
  return ERR_OK;
}
455 
456 /** Receive callback function called from mbedtls (set via mbedtls_ssl_set_bio)
457  * This function mainly copies data from pbufs and frees the pbufs after copying.
458  */
459 static int
altcp_mbedtls_bio_recv(void * ctx,unsigned char * buf,size_t len)460 altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len)
461 {
462   struct altcp_pcb *conn = (struct altcp_pcb *)ctx;
463   altcp_mbedtls_state_t *state;
464   struct pbuf *p;
465   u16_t ret;
466   u16_t copy_len;
467   err_t err;
468 
469   LWIP_UNUSED_ARG(err); /* for LWIP_NOASSERT */
470   if ((conn == NULL) || (conn->state == NULL)) {
471     return MBEDTLS_ERR_NET_INVALID_CONTEXT;
472   }
473   state = (altcp_mbedtls_state_t *)conn->state;
474   LWIP_ASSERT("state != NULL", state != NULL);
475   p = state->rx;
476 
477   /* @todo: return MBEDTLS_ERR_NET_CONN_RESET/MBEDTLS_ERR_NET_RECV_FAILED? */
478 
479   if ((p == NULL) || ((p->len == 0) && (p->next == NULL))) {
480     if (p) {
481       pbuf_free(p);
482     }
483     state->rx = NULL;
484     if ((state->flags & (ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED | ALTCP_MBEDTLS_FLAGS_RX_CLOSED)) ==
485         ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED) {
486       /* close queued but not passed up yet */
487       return 0;
488     }
489     return MBEDTLS_ERR_SSL_WANT_READ;
490   }
491   /* limit number of bytes again to copy from first pbuf in a chain only */
492   copy_len = (u16_t)LWIP_MIN(len, p->len);
493   /* copy the data */
494   ret = pbuf_copy_partial(p, buf, copy_len, 0);
495   LWIP_ASSERT("ret == copy_len", ret == copy_len);
496   /* hide the copied bytes from the pbuf */
497   err = pbuf_remove_header(p, ret);
498   LWIP_ASSERT("error", err == ERR_OK);
499   if (p->len == 0) {
500     /* the first pbuf has been fully read, free it */
501     state->rx = p->next;
502     p->next = NULL;
503     pbuf_free(p);
504   }
505 
506   state->bio_bytes_read += (int)ret;
507   return ret;
508 }
509 
/** Sent callback from lower connection (i.e. TCP)
 * Converts the number of ACKed TCP bytes ('len') into the number of ACKed
 * application bytes and reports only that share to the upper 'sent' callback,
 * so the application does not see the TLS-added bytes.
 * NOTE(review): overhead_bytes_adjust is maintained by the write/bio_send paths
 * outside this view; this function treats (overhead_bytes_adjust + out_left) as
 * the TLS overhead that has not been ACKed yet.
 */
static err_t
altcp_mbedtls_lower_sent(void *arg, struct altcp_pcb *inner_conn, u16_t len)
{
  struct altcp_pcb *conn = (struct altcp_pcb *)arg;
  LWIP_UNUSED_ARG(inner_conn); /* for LWIP_NOASSERT */
  if (conn) {
    int overhead;
    u16_t app_len;
    altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
    LWIP_ASSERT("state", state != NULL);
    LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
    /* calculate TLS overhead part to not send it to application */
    overhead = state->overhead_bytes_adjust + state->ssl_context.out_left;
    /* clamp to the ACKed length; NOTE(review): a negative 'overhead' becomes a
       huge unsigned value here and is clamped to len too (app_len is then 0) */
    if ((unsigned)overhead > len) {
      overhead = len;
    }
    /* remove ACKed bytes from overhead adjust counter */
    state->overhead_bytes_adjust -= len;
    /* try to send more if we failed before (may increase overhead adjust counter);
       this must stay between the '-= len' above and the '+= app_len' below */
    mbedtls_ssl_flush_output(&state->ssl_context);
    /* remove calculated overhead from ACKed bytes len */
    app_len = len - (u16_t)overhead;
    /* update application write counter and inform application */
    if (app_len) {
      state->overhead_bytes_adjust += app_len;
      if (conn->sent)
        return conn->sent(conn->arg, conn, app_len);
    }
  }
  return ERR_OK;
}
546 
547 /** Poll callback from lower connection (i.e. TCP)
548  * Just pass this on to the application.
549  * @todo: retry sending?
550  */
551 static err_t
altcp_mbedtls_lower_poll(void * arg,struct altcp_pcb * inner_conn)552 altcp_mbedtls_lower_poll(void *arg, struct altcp_pcb *inner_conn)
553 {
554   struct altcp_pcb *conn = (struct altcp_pcb *)arg;
555   LWIP_UNUSED_ARG(inner_conn); /* for LWIP_NOASSERT */
556   if (conn) {
557     LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
558     /* check if there's unreceived rx data */
559     if (conn->state) {
560       altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
561       /* try to send more if we failed before */
562       mbedtls_ssl_flush_output(&state->ssl_context);
563       if (altcp_mbedtls_handle_rx_appldata(conn, state) == ERR_ABRT) {
564         return ERR_ABRT;
565       }
566     }
567     if (conn->poll) {
568       return conn->poll(conn->arg, conn);
569     }
570   }
571   return ERR_OK;
572 }
573 
574 static void
altcp_mbedtls_lower_err(void * arg,err_t err)575 altcp_mbedtls_lower_err(void *arg, err_t err)
576 {
577   struct altcp_pcb *conn = (struct altcp_pcb *)arg;
578   if (conn) {
579     conn->inner_conn = NULL; /* already freed */
580     if (conn->err) {
581       conn->err(conn->arg, err);
582     }
583     altcp_free(conn);
584   }
585 }
586 
587 /* setup functions */
588 
589 static void
altcp_mbedtls_remove_callbacks(struct altcp_pcb * inner_conn)590 altcp_mbedtls_remove_callbacks(struct altcp_pcb *inner_conn)
591 {
592   altcp_arg(inner_conn, NULL);
593   altcp_recv(inner_conn, NULL);
594   altcp_sent(inner_conn, NULL);
595   altcp_err(inner_conn, NULL);
596   altcp_poll(inner_conn, NULL, inner_conn->pollinterval);
597 }
598 
599 static void
altcp_mbedtls_setup_callbacks(struct altcp_pcb * conn,struct altcp_pcb * inner_conn)600 altcp_mbedtls_setup_callbacks(struct altcp_pcb *conn, struct altcp_pcb *inner_conn)
601 {
602   altcp_arg(inner_conn, conn);
603   altcp_recv(inner_conn, altcp_mbedtls_lower_recv);
604   altcp_sent(inner_conn, altcp_mbedtls_lower_sent);
605   altcp_err(inner_conn, altcp_mbedtls_lower_err);
606   /* tcp_poll is set when interval is set by application */
607   /* listen is set totally different :-) */
608 }
609 
610 static err_t
altcp_mbedtls_setup(void * conf,struct altcp_pcb * conn,struct altcp_pcb * inner_conn)611 altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn)
612 {
613   int ret;
614   struct altcp_tls_config *config = (struct altcp_tls_config *)conf;
615   altcp_mbedtls_state_t *state;
616   if (!conf) {
617     return ERR_ARG;
618   }
619   LWIP_ASSERT("invalid inner_conn", conn != inner_conn);
620 
621   /* allocate mbedtls context */
622   state = altcp_mbedtls_alloc(conf);
623   if (state == NULL) {
624     return ERR_MEM;
625   }
626   /* initialize mbedtls context: */
627   mbedtls_ssl_init(&state->ssl_context);
628   ret = mbedtls_ssl_setup(&state->ssl_context, &config->conf);
629   if (ret != 0) {
630     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_setup failed\n"));
631     /* @todo: convert 'ret' to err_t */
632     altcp_mbedtls_free(conf, state);
633     return ERR_MEM;
634   }
635   /* tell mbedtls about our I/O functions */
636   mbedtls_ssl_set_bio(&state->ssl_context, conn, altcp_mbedtls_bio_send, altcp_mbedtls_bio_recv, NULL);
637 
638   altcp_mbedtls_setup_callbacks(conn, inner_conn);
639   conn->inner_conn = inner_conn;
640   conn->fns = &altcp_mbedtls_functions;
641   conn->state = state;
642   return ERR_OK;
643 }
644 
645 struct altcp_pcb *
altcp_tls_wrap(struct altcp_tls_config * config,struct altcp_pcb * inner_pcb)646 altcp_tls_wrap(struct altcp_tls_config *config, struct altcp_pcb *inner_pcb)
647 {
648   struct altcp_pcb *ret;
649   if (inner_pcb == NULL) {
650     return NULL;
651   }
652   ret = altcp_alloc();
653   if (ret != NULL) {
654     if (altcp_mbedtls_setup(config, ret, inner_pcb) != ERR_OK) {
655       altcp_free(ret);
656       return NULL;
657     }
658   }
659   return ret;
660 }
661 
662 void
altcp_tls_init_session(struct altcp_tls_session * session)663 altcp_tls_init_session(struct altcp_tls_session *session)
664 {
665   if (session)
666     mbedtls_ssl_session_init(&session->data);
667 }
668 
669 err_t
altcp_tls_get_session(struct altcp_pcb * conn,struct altcp_tls_session * session)670 altcp_tls_get_session(struct altcp_pcb *conn, struct altcp_tls_session *session)
671 {
672   if (session && conn && conn->state) {
673     altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
674     int ret = mbedtls_ssl_get_session(&state->ssl_context, &session->data);
675     return ret < 0 ? ERR_VAL : ERR_OK;
676   }
677   return ERR_ARG;
678 }
679 
680 err_t
altcp_tls_set_session(struct altcp_pcb * conn,struct altcp_tls_session * session)681 altcp_tls_set_session(struct altcp_pcb *conn, struct altcp_tls_session *session)
682 {
683   if (session && conn && conn->state) {
684     altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
685     int ret = -1;
686     if (session->data.start)
687       ret = mbedtls_ssl_set_session(&state->ssl_context, &session->data);
688     return ret < 0 ? ERR_VAL : ERR_OK;
689   }
690   return ERR_ARG;
691 }
692 
693 void
altcp_tls_free_session(struct altcp_tls_session * session)694 altcp_tls_free_session(struct altcp_tls_session *session)
695 {
696   if (session)
697     mbedtls_ssl_session_free(&session->data);
698 }
699 
700 void *
altcp_tls_context(struct altcp_pcb * conn)701 altcp_tls_context(struct altcp_pcb *conn)
702 {
703   if (conn && conn->state) {
704     altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
705     return &state->ssl_context;
706   }
707   return NULL;
708 }
709 
#if ALTCP_MBEDTLS_LIB_DEBUG != LWIP_DBG_OFF
/** mbedTLS debug hook (registered with mbedtls_ssl_conf_dbg()): forwards
 * messages with a level of at least ALTCP_MBEDTLS_LIB_DEBUG_LEVEL_MIN to
 * LWIP_DEBUGF. */
static void
altcp_mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str)
{
  /* presumably unused when LWIP_DEBUGF expands to nothing */
  LWIP_UNUSED_ARG(ctx);
  LWIP_UNUSED_ARG(file);
  LWIP_UNUSED_ARG(line);
  LWIP_UNUSED_ARG(str);

  if (level < ALTCP_MBEDTLS_LIB_DEBUG_LEVEL_MIN) {
    return;
  }
  LWIP_DEBUGF(ALTCP_MBEDTLS_LIB_DEBUG, ("%s:%04d: %s\n", file, line, str));
}
#endif
724 
725 static err_t
altcp_mbedtls_ref_entropy(void)726 altcp_mbedtls_ref_entropy(void)
727 {
728   LWIP_ASSERT_CORE_LOCKED();
729 
730   if (!altcp_tls_entropy_rng) {
731     altcp_tls_entropy_rng = (struct altcp_tls_entropy_rng *)altcp_mbedtls_alloc_config(sizeof(struct altcp_tls_entropy_rng));
732     if (altcp_tls_entropy_rng) {
733       int ret;
734       altcp_tls_entropy_rng->ref = 1;
735       mbedtls_entropy_init(&altcp_tls_entropy_rng->entropy);
736       mbedtls_ctr_drbg_init(&altcp_tls_entropy_rng->ctr_drbg);
737       /* Seed the RNG, only once */
738       ret = mbedtls_ctr_drbg_seed(&altcp_tls_entropy_rng->ctr_drbg,
739                                   ALTCP_MBEDTLS_RNG_FN, &altcp_tls_entropy_rng->entropy,
740                                   ALTCP_MBEDTLS_ENTROPY_PTR, ALTCP_MBEDTLS_ENTROPY_LEN);
741       if (ret != 0) {
742         LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ctr_drbg_seed failed: %d\n", ret));
743         mbedtls_ctr_drbg_free(&altcp_tls_entropy_rng->ctr_drbg);
744         mbedtls_entropy_free(&altcp_tls_entropy_rng->entropy);
745         altcp_mbedtls_free_config(altcp_tls_entropy_rng);
746         altcp_tls_entropy_rng = NULL;
747         return ERR_ARG;
748       }
749     } else {
750       return ERR_MEM;
751     }
752   } else {
753     altcp_tls_entropy_rng->ref++;
754   }
755   return ERR_OK;
756 }
757 
758 static void
altcp_mbedtls_unref_entropy(void)759 altcp_mbedtls_unref_entropy(void)
760 {
761   LWIP_ASSERT_CORE_LOCKED();
762 
763   if (altcp_tls_entropy_rng && altcp_tls_entropy_rng->ref) {
764       altcp_tls_entropy_rng->ref--;
765   }
766 }
767 
768 /** Create new TLS configuration
769  * ATTENTION: Server certificate and private key have to be added outside this function!
770  */
771 static struct altcp_tls_config *
altcp_tls_create_config(int is_server,u8_t cert_count,u8_t pkey_count,int have_ca)772 altcp_tls_create_config(int is_server, u8_t cert_count, u8_t pkey_count, int have_ca)
773 {
774   size_t sz;
775   int ret;
776   struct altcp_tls_config *conf;
777   mbedtls_x509_crt *mem;
778 
779   if (TCP_WND < MBEDTLS_SSL_MAX_CONTENT_LEN) {
780     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG|LWIP_DBG_LEVEL_SERIOUS,
781       ("altcp_tls: TCP_WND is smaller than the RX decrypion buffer, connection RX might stall!\n"));
782   }
783 
784   altcp_mbedtls_mem_init();
785 
786   sz = sizeof(struct altcp_tls_config);
787   if (cert_count > 0) {
788     sz += (cert_count * sizeof(mbedtls_x509_crt));
789   }
790   if (have_ca) {
791     sz += sizeof(mbedtls_x509_crt);
792   }
793   if (pkey_count > 0) {
794     sz += (pkey_count * sizeof(mbedtls_pk_context));
795   }
796 
797   conf = (struct altcp_tls_config *)altcp_mbedtls_alloc_config(sz);
798   if (conf == NULL) {
799     return NULL;
800   }
801   conf->cert_max = cert_count;
802   mem = (mbedtls_x509_crt *)(conf + 1);
803   if (cert_count > 0) {
804     conf->cert = mem;
805     mem += cert_count;
806   }
807   if (have_ca) {
808     conf->ca = mem;
809     mem++;
810   }
811   conf->pkey_max = pkey_count;
812   if (pkey_count > 0) {
813     conf->pkey = (mbedtls_pk_context *)mem;
814   }
815 
816   mbedtls_ssl_config_init(&conf->conf);
817 
818   if (altcp_mbedtls_ref_entropy() != ERR_OK) {
819     altcp_mbedtls_free_config(conf);
820     return NULL;
821   }
822 
823   /* Setup ssl context (@todo: what's different for a client here? -> might better be done on listen/connect) */
824   ret = mbedtls_ssl_config_defaults(&conf->conf, is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT,
825                                     MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
826   if (ret != 0) {
827     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_config_defaults failed: %d\n", ret));
828     altcp_mbedtls_unref_entropy();
829     altcp_mbedtls_free_config(conf);
830     return NULL;
831   }
832   mbedtls_ssl_conf_authmode(&conf->conf, ALTCP_MBEDTLS_AUTHMODE);
833 
834   mbedtls_ssl_conf_rng(&conf->conf, mbedtls_ctr_drbg_random, &altcp_tls_entropy_rng->ctr_drbg);
835 #if ALTCP_MBEDTLS_LIB_DEBUG != LWIP_DBG_OFF
836   mbedtls_ssl_conf_dbg(&conf->conf, altcp_mbedtls_debug, stdout);
837 #endif
838 #if defined(MBEDTLS_SSL_CACHE_C) && ALTCP_MBEDTLS_USE_SESSION_CACHE
839   mbedtls_ssl_conf_session_cache(&conf->conf, &conf->cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set);
840   mbedtls_ssl_cache_set_timeout(&conf->cache, ALTCP_MBEDTLS_SESSION_CACHE_TIMEOUT_SECONDS);
841   mbedtls_ssl_cache_set_max_entries(&conf->cache, ALTCP_MBEDTLS_SESSION_CACHE_SIZE);
842 #endif
843 
844 #if defined(MBEDTLS_SSL_SESSION_TICKETS) && ALTCP_MBEDTLS_USE_SESSION_TICKETS
845   mbedtls_ssl_ticket_init(&conf->ticket_ctx);
846 
847   ret = mbedtls_ssl_ticket_setup(&conf->ticket_ctx, mbedtls_ctr_drbg_random, &altcp_tls_entropy_rng->ctr_drbg,
848     ALTCP_MBEDTLS_SESSION_TICKET_CIPHER, ALTCP_MBEDTLS_SESSION_TICKET_TIMEOUT_SECONDS);
849   if (ret) {
850     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_ticket_setup failed: %d\n", ret));
851     altcp_mbedtls_unref_entropy();
852     altcp_mbedtls_free_config(conf);
853     return NULL;
854   }
855 
856   mbedtls_ssl_conf_session_tickets_cb(&conf->conf, mbedtls_ssl_ticket_write, mbedtls_ssl_ticket_parse,
857     &conf->ticket_ctx);
858 #endif
859 
860   return conf;
861 }
862 
/** Create a server TLS configuration with room for @p cert_count
 * certificate/private-key pairs (to be filled via
 * altcp_tls_config_server_add_privkey_cert()). No CA chain is configured.
 *
 * @return the configuration, or NULL on failure
 */
struct altcp_tls_config *altcp_tls_create_config_server(u8_t cert_count)
{
  struct altcp_tls_config *conf;

  /* one key slot per certificate slot, no CA slot */
  conf = altcp_tls_create_config(1, cert_count, cert_count, 0);
  if (conf != NULL) {
    mbedtls_ssl_conf_ca_chain(&conf->conf, NULL, NULL);
  }
  return conf;
}
873 
altcp_tls_config_server_add_privkey_cert(struct altcp_tls_config * config,const u8_t * privkey,size_t privkey_len,const u8_t * privkey_pass,size_t privkey_pass_len,const u8_t * cert,size_t cert_len)874 err_t altcp_tls_config_server_add_privkey_cert(struct altcp_tls_config *config,
875       const u8_t *privkey, size_t privkey_len,
876       const u8_t *privkey_pass, size_t privkey_pass_len,
877       const u8_t *cert, size_t cert_len)
878 {
879   int ret;
880   mbedtls_x509_crt *srvcert;
881   mbedtls_pk_context *pkey;
882 
883   if (config->cert_count >= config->cert_max) {
884     return ERR_MEM;
885   }
886   if (config->pkey_count >= config->pkey_max) {
887     return ERR_MEM;
888   }
889 
890   srvcert = config->cert + config->cert_count;
891   mbedtls_x509_crt_init(srvcert);
892 
893   pkey = config->pkey + config->pkey_count;
894   mbedtls_pk_init(pkey);
895 
896   /* Load the certificates and private key */
897   ret = mbedtls_x509_crt_parse(srvcert, cert, cert_len);
898   if (ret != 0) {
899     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_x509_crt_parse failed: %d\n", ret));
900     return ERR_VAL;
901   }
902 
903   ret = mbedtls_pk_parse_key(pkey, (const unsigned char *) privkey, privkey_len, privkey_pass, privkey_pass_len);
904   if (ret != 0) {
905     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_pk_parse_public_key failed: %d\n", ret));
906     mbedtls_x509_crt_free(srvcert);
907     return ERR_VAL;
908   }
909 
910   ret = mbedtls_ssl_conf_own_cert(&config->conf, srvcert, pkey);
911   if (ret != 0) {
912     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_conf_own_cert failed: %d\n", ret));
913     mbedtls_x509_crt_free(srvcert);
914     mbedtls_pk_free(pkey);
915     return ERR_VAL;
916   }
917 
918   config->cert_count++;
919   config->pkey_count++;
920   return ERR_OK;
921 }
922 
923 /** Create new TLS configuration
924  * This is a suboptimal version that gets the encrypted private key and its password,
925  * as well as the server certificate.
926  */
927 struct altcp_tls_config *
altcp_tls_create_config_server_privkey_cert(const u8_t * privkey,size_t privkey_len,const u8_t * privkey_pass,size_t privkey_pass_len,const u8_t * cert,size_t cert_len)928 altcp_tls_create_config_server_privkey_cert(const u8_t *privkey, size_t privkey_len,
929     const u8_t *privkey_pass, size_t privkey_pass_len,
930     const u8_t *cert, size_t cert_len)
931 {
932   struct altcp_tls_config *conf = altcp_tls_create_config_server(1);
933   if (conf == NULL) {
934     return NULL;
935   }
936 
937   if (altcp_tls_config_server_add_privkey_cert(conf, privkey, privkey_len,
938     privkey_pass, privkey_pass_len, cert, cert_len) != ERR_OK) {
939     altcp_tls_free_config(conf);
940     return NULL;
941   }
942 
943   return conf;
944 }
945 
946 static struct altcp_tls_config *
altcp_tls_create_config_client_common(const u8_t * ca,size_t ca_len,int is_2wayauth)947 altcp_tls_create_config_client_common(const u8_t *ca, size_t ca_len, int is_2wayauth)
948 {
949   int ret;
950   struct altcp_tls_config *conf = altcp_tls_create_config(0, (is_2wayauth) ? 1 : 0, (is_2wayauth) ? 1 : 0, ca != NULL);
951   if (conf == NULL) {
952     return NULL;
953   }
954 
955   /* Initialize the CA certificate if provided
956    * CA certificate is optional (to save memory) but recommended for production environment
957    * Without CA certificate, connection will be prone to man-in-the-middle attacks */
958   if (ca) {
959     mbedtls_x509_crt_init(conf->ca);
960     ret = mbedtls_x509_crt_parse(conf->ca, ca, ca_len);
961     if (ret != 0) {
962       LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_x509_crt_parse ca failed: %d 0x%x\n", ret, -1*ret));
963       altcp_tls_free_config(conf);
964       return NULL;
965     }
966 
967     mbedtls_ssl_conf_ca_chain(&conf->conf, conf->ca, NULL);
968   }
969   return conf;
970 }
971 
972 struct altcp_tls_config *
altcp_tls_create_config_client(const u8_t * ca,size_t ca_len)973 altcp_tls_create_config_client(const u8_t *ca, size_t ca_len)
974 {
975   return altcp_tls_create_config_client_common(ca, ca_len, 0);
976 }
977 
978 struct altcp_tls_config *
altcp_tls_create_config_client_2wayauth(const u8_t * ca,size_t ca_len,const u8_t * privkey,size_t privkey_len,const u8_t * privkey_pass,size_t privkey_pass_len,const u8_t * cert,size_t cert_len)979 altcp_tls_create_config_client_2wayauth(const u8_t *ca, size_t ca_len, const u8_t *privkey, size_t privkey_len,
980                                         const u8_t *privkey_pass, size_t privkey_pass_len,
981                                         const u8_t *cert, size_t cert_len)
982 {
983   int ret;
984   struct altcp_tls_config *conf;
985 
986   if (!cert || !privkey) {
987     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("altcp_tls_create_config_client_2wayauth: certificate and priv key required\n"));
988     return NULL;
989   }
990 
991   conf = altcp_tls_create_config_client_common(ca, ca_len, 1);
992   if (conf == NULL) {
993     return NULL;
994   }
995 
996   /* Initialize the client certificate and corresponding private key */
997   mbedtls_x509_crt_init(conf->cert);
998   ret = mbedtls_x509_crt_parse(conf->cert, cert, cert_len);
999   if (ret != 0) {
1000     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_x509_crt_parse cert failed: %d 0x%x\n", ret, -1*ret));
1001     altcp_tls_free_config(conf);
1002     return NULL;
1003   }
1004 
1005   mbedtls_pk_init(conf->pkey);
1006   ret = mbedtls_pk_parse_key(conf->pkey, privkey, privkey_len, privkey_pass, privkey_pass_len);
1007   if (ret != 0) {
1008     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_pk_parse_key failed: %d 0x%x\n", ret, -1*ret));
1009     altcp_tls_free_config(conf);
1010     return NULL;
1011   }
1012 
1013   ret = mbedtls_ssl_conf_own_cert(&conf->conf, conf->cert, conf->pkey);
1014   if (ret != 0) {
1015     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_conf_own_cert failed: %d 0x%x\n", ret, -1*ret));
1016     altcp_tls_free_config(conf);
1017     return NULL;
1018   }
1019 
1020   return conf;
1021 }
1022 
/** Configure the ALPN protocol list of a TLS configuration.
 *
 * @param conf TLS configuration
 * @param protos protocol name list, passed unchanged to
 *        mbedtls_ssl_conf_alpn_protocols() (per the mbedTLS docs it must be
 *        NULL-terminated and outlive the configuration — confirm there)
 * @return 0 on success, the mbedTLS error code on failure, or -1 when mbedTLS
 *         was built without MBEDTLS_SSL_ALPN
 */
int
altcp_tls_configure_alpn_protocols(struct altcp_tls_config *conf, const char **protos)
{
#if defined(MBEDTLS_SSL_ALPN)
  int ret = mbedtls_ssl_conf_alpn_protocols(&conf->conf, protos);
  if (ret != 0) {
    LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_conf_alpn_protocols failed: %d\n", ret));
  }

  return ret;
#else
  /* ALPN support not compiled in: silence unused-parameter warnings */
  (void)conf;
  (void)protos;
  return -1;
#endif
}
1037 
1038 void
altcp_tls_free_config(struct altcp_tls_config * conf)1039 altcp_tls_free_config(struct altcp_tls_config *conf)
1040 {
1041   if (conf->pkey) {
1042     mbedtls_pk_free(conf->pkey);
1043   }
1044   if (conf->cert) {
1045     mbedtls_x509_crt_free(conf->cert);
1046   }
1047   if (conf->ca) {
1048     mbedtls_x509_crt_free(conf->ca);
1049   }
1050   mbedtls_ssl_config_free(&conf->conf);
1051   altcp_mbedtls_free_config(conf);
1052   altcp_mbedtls_unref_entropy();
1053 }
1054 
1055 void
altcp_tls_free_entropy(void)1056 altcp_tls_free_entropy(void)
1057 {
1058   LWIP_ASSERT_CORE_LOCKED();
1059 
1060   if (altcp_tls_entropy_rng && altcp_tls_entropy_rng->ref == 0) {
1061     mbedtls_ctr_drbg_free(&altcp_tls_entropy_rng->ctr_drbg);
1062     mbedtls_entropy_free(&altcp_tls_entropy_rng->entropy);
1063     altcp_mbedtls_free_config(altcp_tls_entropy_rng);
1064     altcp_tls_entropy_rng = NULL;
1065   }
1066 }
1067 
1068 /* "virtual" functions */
1069 static void
altcp_mbedtls_set_poll(struct altcp_pcb * conn,u8_t interval)1070 altcp_mbedtls_set_poll(struct altcp_pcb *conn, u8_t interval)
1071 {
1072   if (conn != NULL) {
1073     altcp_poll(conn->inner_conn, altcp_mbedtls_lower_poll, interval);
1074   }
1075 }
1076 
1077 static void
altcp_mbedtls_recved(struct altcp_pcb * conn,u16_t len)1078 altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len)
1079 {
1080   u16_t lower_recved;
1081   altcp_mbedtls_state_t *state;
1082   if (conn == NULL) {
1083     return;
1084   }
1085   state = (altcp_mbedtls_state_t *)conn->state;
1086   if (state == NULL) {
1087     return;
1088   }
1089   if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
1090     return;
1091   }
1092   lower_recved = len;
1093   if (lower_recved > state->rx_passed_unrecved) {
1094     LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("bogus recved count (len > state->rx_passed_unrecved / %d / %d)\n",
1095                                       len, state->rx_passed_unrecved));
1096     lower_recved = (u16_t)state->rx_passed_unrecved;
1097   }
1098   state->rx_passed_unrecved -= lower_recved;
1099 
1100   altcp_recved(conn->inner_conn, lower_recved);
1101 }
1102 
1103 static err_t
altcp_mbedtls_connect(struct altcp_pcb * conn,const ip_addr_t * ipaddr,u16_t port,altcp_connected_fn connected)1104 altcp_mbedtls_connect(struct altcp_pcb *conn, const ip_addr_t *ipaddr, u16_t port, altcp_connected_fn connected)
1105 {
1106   if (conn == NULL) {
1107     return ERR_VAL;
1108   }
1109   conn->connected = connected;
1110   return altcp_connect(conn->inner_conn, ipaddr, port, altcp_mbedtls_lower_connected);
1111 }
1112 
1113 static struct altcp_pcb *
altcp_mbedtls_listen(struct altcp_pcb * conn,u8_t backlog,err_t * err)1114 altcp_mbedtls_listen(struct altcp_pcb *conn, u8_t backlog, err_t *err)
1115 {
1116   struct altcp_pcb *lpcb;
1117   if (conn == NULL) {
1118     return NULL;
1119   }
1120   lpcb = altcp_listen_with_backlog_and_err(conn->inner_conn, backlog, err);
1121   if (lpcb != NULL) {
1122     altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
1123     /* Free members of the ssl context (not used on listening pcb). This
1124        includes freeing input/output buffers, so saves ~32KByte by default */
1125     mbedtls_ssl_free(&state->ssl_context);
1126 
1127     conn->inner_conn = lpcb;
1128     altcp_accept(lpcb, altcp_mbedtls_lower_accept);
1129     return conn;
1130   }
1131   return NULL;
1132 }
1133 
1134 static void
altcp_mbedtls_abort(struct altcp_pcb * conn)1135 altcp_mbedtls_abort(struct altcp_pcb *conn)
1136 {
1137   if (conn != NULL) {
1138     altcp_abort(conn->inner_conn);
1139   }
1140 }
1141 
1142 static err_t
altcp_mbedtls_close(struct altcp_pcb * conn)1143 altcp_mbedtls_close(struct altcp_pcb *conn)
1144 {
1145   struct altcp_pcb *inner_conn;
1146   if (conn == NULL) {
1147     return ERR_VAL;
1148   }
1149   inner_conn = conn->inner_conn;
1150   if (inner_conn) {
1151     err_t err;
1152     altcp_poll_fn oldpoll = inner_conn->poll;
1153     altcp_mbedtls_remove_callbacks(conn->inner_conn);
1154     err = altcp_close(conn->inner_conn);
1155     if (err != ERR_OK) {
1156       /* not closed, set up all callbacks again */
1157       altcp_mbedtls_setup_callbacks(conn, inner_conn);
1158       /* poll callback is not included in the above */
1159       altcp_poll(inner_conn, oldpoll, inner_conn->pollinterval);
1160       return err;
1161     }
1162     conn->inner_conn = NULL;
1163   }
1164   altcp_free(conn);
1165   return ERR_OK;
1166 }
1167 
1168 /** Allow caller of altcp_write() to limit to negotiated chunk size
1169  *  or remaining sndbuf space of inner_conn.
1170  */
1171 static u16_t
altcp_mbedtls_sndbuf(struct altcp_pcb * conn)1172 altcp_mbedtls_sndbuf(struct altcp_pcb *conn)
1173 {
1174   if (conn) {
1175     altcp_mbedtls_state_t *state;
1176     state = (altcp_mbedtls_state_t*)conn->state;
1177     if (!state || !(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
1178       return 0;
1179     }
1180     if (conn->inner_conn) {
1181       u16_t sndbuf = altcp_sndbuf(conn->inner_conn);
1182       /* Take care of record header, IV, AuthTag */
1183       int ssl_expan = mbedtls_ssl_get_record_expansion(&state->ssl_context);
1184       if (ssl_expan > 0) {
1185         size_t ssl_added = (u16_t)LWIP_MIN(ssl_expan, 0xFFFF);
1186         /* internal sndbuf smaller than our offset */
1187         if (ssl_added < sndbuf) {
1188           size_t max_len = 0xFFFF;
1189           size_t ret;
1190 #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
1191           /* @todo: adjust ssl_added to real value related to negotiated cipher */
1192           size_t max_frag_len = mbedtls_ssl_get_max_frag_len(&state->ssl_context);
1193           max_len = LWIP_MIN(max_frag_len, max_len);
1194 #endif
1195           /* Adjust sndbuf of inner_conn with what added by SSL */
1196           ret = LWIP_MIN(sndbuf - ssl_added, max_len);
1197           LWIP_ASSERT("sndbuf overflow", ret <= 0xFFFF);
1198           return (u16_t)ret;
1199         }
1200       }
1201     }
1202   }
1203   /* fallback: use sendbuf of the inner connection */
1204   return altcp_default_sndbuf(conn);
1205 }
1206 
/** Write data to a TLS connection. Calls into mbedTLS, which in turn calls into
 * @ref altcp_mbedtls_bio_send() to send the encrypted data
 *
 * Only allowed once the handshake is done. A record still pending inside
 * mbedTLS (ssl_context.out_left) is flushed first, and new data is only
 * accepted if that flush emptied it. After encrypting, altcp_output() is
 * called on the inner connection.
 *
 * @param conn TLS connection
 * @param dataptr plaintext to encrypt and send
 * @param len number of plaintext bytes
 * @param apiflags ignored (the send callback always passes TCP_WRITE_FLAG_COPY down)
 * @return ERR_OK if mbedTLS accepted all @p len bytes;
 *         ERR_MEM if a pending record could not be flushed, mbedTLS returned
 *         MBEDTLS_ERR_SSL_WANT_WRITE, or fewer than @p len bytes were accepted;
 *         ERR_ARG if @p conn has no TLS state;
 *         ERR_VAL for a NULL @p conn, before the handshake is done, or on any
 *         other mbedTLS error
 */
static err_t
altcp_mbedtls_write(struct altcp_pcb *conn, const void *dataptr, u16_t len, u8_t apiflags)
{
  int ret;
  altcp_mbedtls_state_t *state;

  LWIP_UNUSED_ARG(apiflags);

  if (conn == NULL) {
    return ERR_VAL;
  }

  state = (altcp_mbedtls_state_t *)conn->state;
  if (state == NULL) {
    /* @todo: which error? */
    return ERR_ARG;
  }
  if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
    /* @todo: which error? */
    return ERR_VAL;
  }

  /* HACK: if there is something left to send, try to flush it and only
     allow sending more if this succeeded (this is a hack because neither
     returning 0 nor MBEDTLS_ERR_SSL_WANT_WRITE worked for me) */
  if (state->ssl_context.out_left) {
    mbedtls_ssl_flush_output(&state->ssl_context);
    if (state->ssl_context.out_left) {
      return ERR_MEM;
    }
  }
  ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len);
  /* try to send data... */
  altcp_output(conn->inner_conn);
  if (ret >= 0) {
    if (ret == len) {
      /* update application sent counter */
      /* bio_send added every encrypted byte to overhead_bytes_adjust;
         subtracting the plaintext length leaves the TLS overhead there */
      state->overhead_bytes_adjust -= ret;
      return ERR_OK;
    } else {
      /* @todo/@fixme: assumption: either everything sent or error */
      /* NOTE(review): mbedtls_ssl_write() may accept fewer than len bytes
         (e.g. when len exceeds the maximum record payload). Those bytes are
         already encrypted and queued, yet ERR_MEM is returned, so a caller
         retrying the whole buffer would send them twice — TODO confirm */
      LWIP_ASSERT("ret <= 0", 0);
      return ERR_MEM;
    }
  } else {
    if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
      /* @todo: convert error to err_t */
      return ERR_MEM;
    }
    LWIP_ASSERT("unhandled error", 0);
    return ERR_VAL;
  }
}
1263 
1264 /** Send callback function called from mbedtls (set via mbedtls_ssl_set_bio)
1265  * This function is either called during handshake or when sending application
1266  * data via @ref altcp_mbedtls_write (or altcp_write)
1267  */
1268 static int
altcp_mbedtls_bio_send(void * ctx,const unsigned char * dataptr,size_t size)1269 altcp_mbedtls_bio_send(void *ctx, const unsigned char *dataptr, size_t size)
1270 {
1271   struct altcp_pcb *conn = (struct altcp_pcb *) ctx;
1272   altcp_mbedtls_state_t *state;
1273   int written = 0;
1274   size_t size_left = size;
1275   u8_t apiflags = TCP_WRITE_FLAG_COPY;
1276 
1277   LWIP_ASSERT("conn != NULL", conn != NULL);
1278   if ((conn == NULL) || (conn->inner_conn == NULL)) {
1279     return MBEDTLS_ERR_NET_INVALID_CONTEXT;
1280   }
1281   state = (altcp_mbedtls_state_t *)conn->state;
1282   LWIP_ASSERT("state != NULL", state != NULL);
1283 
1284   while (size_left) {
1285     u16_t write_len = (u16_t)LWIP_MIN(size_left, 0xFFFF);
1286     err_t err = altcp_write(conn->inner_conn, (const void *)dataptr, write_len, apiflags);
1287     if (err == ERR_OK) {
1288       written += write_len;
1289       size_left -= write_len;
1290       state->overhead_bytes_adjust += write_len;
1291     } else if (err == ERR_MEM) {
1292       if (written) {
1293         return written;
1294       }
1295       return 0; /* MBEDTLS_ERR_SSL_WANT_WRITE; */
1296     } else {
1297       LWIP_ASSERT("tls_write, tcp_write: err != ERR MEM", 0);
1298       /* @todo: return MBEDTLS_ERR_NET_CONN_RESET or MBEDTLS_ERR_NET_SEND_FAILED */
1299       return MBEDTLS_ERR_NET_SEND_FAILED;
1300     }
1301   }
1302   return written;
1303 }
1304 
1305 static u16_t
altcp_mbedtls_mss(struct altcp_pcb * conn)1306 altcp_mbedtls_mss(struct altcp_pcb *conn)
1307 {
1308   if (conn == NULL) {
1309     return 0;
1310   }
1311   /* @todo: LWIP_MIN(mss, mbedtls_ssl_get_max_frag_len()) ? */
1312   return altcp_mss(conn->inner_conn);
1313 }
1314 
1315 static void
altcp_mbedtls_dealloc(struct altcp_pcb * conn)1316 altcp_mbedtls_dealloc(struct altcp_pcb *conn)
1317 {
1318   /* clean up and free tls state */
1319   if (conn) {
1320     altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
1321     if (state) {
1322       mbedtls_ssl_free(&state->ssl_context);
1323       state->flags = 0;
1324       if (state->rx) {
1325         /* free leftover (unhandled) rx pbufs */
1326         pbuf_free(state->rx);
1327         state->rx = NULL;
1328       }
1329       altcp_mbedtls_free(state->conf, state);
1330       conn->state = NULL;
1331     }
1332   }
1333 }
1334 
/** altcp function table of the mbedTLS-based TLS layer.
 * The initializer is positional, so the entry order must match the member
 * order of struct altcp_functions (declared outside this file).
 * The altcp_mbedtls_* entries are implemented in this file. The
 * altcp_default_* entries are generic helpers (presumably operating on the
 * inner connection — confirm in altcp.c).
 */
const struct altcp_functions altcp_mbedtls_functions = {
  altcp_mbedtls_set_poll,
  altcp_mbedtls_recved,
  altcp_default_bind,
  altcp_mbedtls_connect,
  altcp_mbedtls_listen,
  altcp_mbedtls_abort,
  altcp_mbedtls_close,
  altcp_default_shutdown,
  altcp_mbedtls_write,
  altcp_default_output,
  altcp_mbedtls_mss,
  altcp_mbedtls_sndbuf,
  altcp_default_sndqueuelen,
  altcp_default_nagle_disable,
  altcp_default_nagle_enable,
  altcp_default_nagle_disabled,
  altcp_default_setprio,
  altcp_mbedtls_dealloc,
  altcp_default_get_tcp_addrinfo,
  altcp_default_get_ip,
  altcp_default_get_port
  /* optional trailing entries; presumably mirror the same #if conditions on
     the struct members — confirm in altcp.h */
#if LWIP_TCP_KEEPALIVE
  , altcp_default_keepalive_disable
  , altcp_default_keepalive_enable
#endif
#ifdef LWIP_DEBUG
  , altcp_default_dbg_get_tcp_state
#endif
};
1365 
1366 #endif /* LWIP_ALTCP_TLS && LWIP_ALTCP_TLS_MBEDTLS */
1367 #endif /* LWIP_ALTCP */
1368