xref: /reactos/drivers/network/tcpip/ip/lwip_glue/tcp.c (revision d6eebaa4)
1 #include <debug.h>
2 #include <lwip/tcpip.h>
3 
4 #include "lwip_glue.h"
5 
/* Printable names for TCP connection states; LibTCPDumpPcb indexes this table
 * with pcb->state, so the entry order is significant. */
static const char * const tcp_state_str[] = {
  /*  0 */ "CLOSED",
  /*  1 */ "LISTEN",
  /*  2 */ "SYN_SENT",
  /*  3 */ "SYN_RCVD",
  /*  4 */ "ESTABLISHED",
  /*  5 */ "FIN_WAIT_1",
  /*  6 */ "FIN_WAIT_2",
  /*  7 */ "CLOSE_WAIT",
  /*  8 */ "CLOSING",
  /*  9 */ "LAST_ACK",
  /* 10 */ "TIME_WAIT"
};
19 
/* The way that lwIP does multi-threading is really not ideal for our purposes, but
 * we had best go along with it unless we want another unstable TCP library. lwIP uses
 * a thread called the "tcpip thread", which is the only one allowed to call raw API
 * functions. Because of this, for each of our LibTCP* functions we queue a request
 * for a callback to the "tcpip thread", which calls our LibTCP*Callback functions. Yes, this is
 * a lot of unnecessary thread swapping and it could definitely be faster, but I don't want
 * to go messing around in lwIP because I have no desire to create another mess like oskittcp */
27 
28 extern KEVENT TerminationEvent;
29 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
30 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
31 
32 /* Required for ERR_T to NTSTATUS translation in receive error handling */
33 NTSTATUS TCPTranslateError(const err_t err);
34 
35 void
LibTCPDumpPcb(PVOID SocketContext)36 LibTCPDumpPcb(PVOID SocketContext)
37 {
38     struct tcp_pcb *pcb = (struct tcp_pcb*)SocketContext;
39     unsigned int addr = lwip_ntohl(pcb->remote_ip.addr);
40 
41     DbgPrint("\tState: %s\n", tcp_state_str[pcb->state]);
42     DbgPrint("\tRemote: (%d.%d.%d.%d, %d)\n",
43     (addr >> 24) & 0xFF,
44     (addr >> 16) & 0xFF,
45     (addr >> 8) & 0xFF,
46     addr & 0xFF,
47     pcb->remote_port);
48 }
49 
50 static
51 void
LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)52 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
53 {
54     PLIST_ENTRY Entry;
55     PQUEUE_ENTRY qp = NULL;
56 
57     ReferenceObject(Connection);
58 
59     while (!IsListEmpty(&Connection->PacketQueue))
60     {
61         Entry = RemoveHeadList(&Connection->PacketQueue);
62         qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
63 
64         /* We're in the tcpip thread here so this is safe */
65         pbuf_free(qp->p);
66 
67         ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
68     }
69 
70     DereferenceObject(Connection);
71 }
72 
/* Appends a received pbuf to the tail of the connection's packet queue.
 *
 * Connection - endpoint whose PacketQueue receives the new entry; the list is
 *              modified with the connection locked.
 * p          - pbuf to queue; it is stored in the entry with a read offset of 0.
 *
 * If no queue entry can be allocated, the pbuf cannot be queued. This function
 * returns void, so the failure cannot be reported to the caller: the data is
 * dropped and a message is printed rather than dereferencing a NULL entry.
 */
void LibTCPEnqueuePacket(PCONNECTION_ENDPOINT Connection, struct pbuf *p)
{
    PQUEUE_ENTRY qp;

    qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
    if (qp == NULL)
    {
        DbgPrint("LibTCPEnqueuePacket: no queue entry available, dropping received data\n");

        /* Release the pbuf through the deferred-free routine rather than pbuf_free():
         * the receive handler still reads p->tot_len after this call returns.
         * NOTE(review): assumes the deferred free cannot run before the current
         * tcpip-thread callback returns - confirm against the lwIP version in use. */
        pbuf_free_callback(p);
        return;
    }

    qp->p = p;
    qp->Offset = 0;

    LockObject(Connection);
    InsertTailList(&Connection->PacketQueue, &qp->ListEntry);
    UnlockObject(Connection);
}
85 
LibTCPDequeuePacket(PCONNECTION_ENDPOINT Connection)86 PQUEUE_ENTRY LibTCPDequeuePacket(PCONNECTION_ENDPOINT Connection)
87 {
88     PLIST_ENTRY Entry;
89     PQUEUE_ENTRY qp = NULL;
90 
91     if (IsListEmpty(&Connection->PacketQueue)) return NULL;
92 
93     Entry = RemoveHeadList(&Connection->PacketQueue);
94 
95     qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
96 
97     return qp;
98 }
99 
LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection,PUCHAR RecvBuffer,UINT RecvLen,UINT * Received)100 NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHAR RecvBuffer, UINT RecvLen, UINT *Received)
101 {
102     PQUEUE_ENTRY qp;
103     struct pbuf* p;
104     NTSTATUS Status;
105     UINT ReadLength, PayloadLength, Offset, Copied;
106 
107     (*Received) = 0;
108 
109     LockObject(Connection);
110 
111     if (!IsListEmpty(&Connection->PacketQueue))
112     {
113         while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
114         {
115             p = qp->p;
116 
117             /* Calculate the payload length first */
118             PayloadLength = p->tot_len;
119             PayloadLength -= qp->Offset;
120             Offset = qp->Offset;
121 
122             /* Check if we're reading the whole buffer */
123             ReadLength = MIN(PayloadLength, RecvLen);
124             ASSERT(ReadLength != 0);
125             if (ReadLength != PayloadLength)
126             {
127                 /* Save this one for later */
128                 qp->Offset += ReadLength;
129                 InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
130                 qp = NULL;
131             }
132 
133             Copied = pbuf_copy_partial(p, RecvBuffer, ReadLength, Offset);
134             ASSERT(Copied == ReadLength);
135 
136             /* Update trackers */
137             RecvLen -= ReadLength;
138             RecvBuffer += ReadLength;
139             (*Received) += ReadLength;
140 
141             if (qp != NULL)
142             {
143                 /* Use this special pbuf free callback function because we're outside tcpip thread */
144                 pbuf_free_callback(qp->p);
145 
146                 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
147             }
148             else
149             {
150                 /* If we get here, it means we've filled the buffer */
151                 ASSERT(RecvLen == 0);
152             }
153 
154             ASSERT((*Received) != 0);
155             Status = STATUS_SUCCESS;
156 
157             if (!RecvLen)
158                 break;
159         }
160     }
161     else
162     {
163         if (Connection->ReceiveShutdown)
164             Status = Connection->ReceiveShutdownStatus;
165         else
166             Status = STATUS_PENDING;
167     }
168 
169     UnlockObject(Connection);
170 
171     return Status;
172 }
173 
174 static
175 BOOLEAN
WaitForEventSafely(PRKEVENT Event)176 WaitForEventSafely(PRKEVENT Event)
177 {
178     PVOID WaitObjects[] = {Event, &TerminationEvent};
179 
180     if (KeWaitForMultipleObjects(2,
181                                  WaitObjects,
182                                  WaitAny,
183                                  Executive,
184                                  KernelMode,
185                                  FALSE,
186                                  NULL,
187                                  NULL) == STATUS_WAIT_0)
188     {
189         /* Signalled by the caller's event */
190         return TRUE;
191     }
192     else /* if KeWaitForMultipleObjects() == STATUS_WAIT_1 */
193     {
194         /* Signalled by our termination event */
195         return FALSE;
196     }
197 }
198 
199 static
200 err_t
InternalSendEventHandler(void * arg,PTCP_PCB pcb,const u16_t space)201 InternalSendEventHandler(void *arg, PTCP_PCB pcb, const u16_t space)
202 {
203     /* Make sure the socket didn't get closed */
204     if (!arg) return ERR_OK;
205 
206     TCPSendEventHandler(arg, space);
207 
208     return ERR_OK;
209 }
210 
211 static
212 err_t
InternalRecvEventHandler(void * arg,PTCP_PCB pcb,struct pbuf * p,const err_t err)213 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
214 {
215     PCONNECTION_ENDPOINT Connection = arg;
216 
217     /* Make sure the socket didn't get closed */
218     if (!arg)
219     {
220         if (p)
221             pbuf_free(p);
222 
223         return ERR_OK;
224     }
225 
226     if (p)
227     {
228         LibTCPEnqueuePacket(Connection, p);
229 
230         tcp_recved(pcb, p->tot_len);
231 
232         TCPRecvEventHandler(arg);
233     }
234     else if (err == ERR_OK)
235     {
236         /* Complete pending reads with 0 bytes to indicate a graceful closure,
237          * but note that send is still possible in this state so we don't close the
238          * whole socket here (by calling tcp_close()) as that would violate TCP specs
239          */
240         Connection->ReceiveShutdown = TRUE;
241         Connection->ReceiveShutdownStatus = STATUS_SUCCESS;
242 
243         /* If we already did a send shutdown, we're in TIME_WAIT so we can't use this PCB anymore */
244         if (Connection->SendShutdown)
245         {
246             Connection->SocketContext = NULL;
247             tcp_arg(pcb, NULL);
248         }
249 
250         /* Indicate the graceful close event */
251         TCPRecvEventHandler(arg);
252 
253         /* If the PCB is gone, clean up the connection */
254         if (Connection->SendShutdown)
255         {
256             TCPFinEventHandler(Connection, ERR_CLSD);
257         }
258     }
259 
260     return ERR_OK;
261 }
262 
263 /* This function MUST return an error value that is not ERR_ABRT or ERR_OK if the connection
264  * is not accepted to avoid leaking the new PCB */
265 static
266 err_t
InternalAcceptEventHandler(void * arg,PTCP_PCB newpcb,const err_t err)267 InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
268 {
269     /* Make sure the socket didn't get closed */
270     if (!arg)
271         return ERR_CLSD;
272 
273     TCPAcceptEventHandler(arg, newpcb);
274 
275     /* Set in LibTCPAccept (called from TCPAcceptEventHandler) */
276     if (newpcb->callback_arg)
277         return ERR_OK;
278     else
279         return ERR_CLSD;
280 }
281 
282 static
283 err_t
InternalConnectEventHandler(void * arg,PTCP_PCB pcb,const err_t err)284 InternalConnectEventHandler(void *arg, PTCP_PCB pcb, const err_t err)
285 {
286     /* Make sure the socket didn't get closed */
287     if (!arg)
288         return ERR_OK;
289 
290     TCPConnectEventHandler(arg, err);
291 
292     return ERR_OK;
293 }
294 
295 static
296 void
InternalErrorEventHandler(void * arg,const err_t err)297 InternalErrorEventHandler(void *arg, const err_t err)
298 {
299     PCONNECTION_ENDPOINT Connection = arg;
300 
301     /* Make sure the socket didn't get closed */
302     if (!arg || Connection->SocketContext == NULL) return;
303 
304     /* The PCB is dead now */
305     Connection->SocketContext = NULL;
306 
307     /* Give them one shot to receive the remaining data */
308     Connection->ReceiveShutdown = TRUE;
309     Connection->ReceiveShutdownStatus = TCPTranslateError(err);
310     TCPRecvEventHandler(Connection);
311 
312     /* Terminate the connection */
313     TCPFinEventHandler(Connection, err);
314 }
315 
316 static
317 void
LibTCPSocketCallback(void * arg)318 LibTCPSocketCallback(void *arg)
319 {
320     struct lwip_callback_msg *msg = arg;
321 
322     ASSERT(msg);
323 
324     msg->Output.Socket.NewPcb = tcp_new();
325 
326     if (msg->Output.Socket.NewPcb)
327     {
328         tcp_arg(msg->Output.Socket.NewPcb, msg->Input.Socket.Arg);
329         tcp_err(msg->Output.Socket.NewPcb, InternalErrorEventHandler);
330     }
331 
332     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
333 }
334 
335 struct tcp_pcb *
LibTCPSocket(void * arg)336 LibTCPSocket(void *arg)
337 {
338     struct lwip_callback_msg *msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
339     struct tcp_pcb *ret;
340 
341     if (msg)
342     {
343         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
344         msg->Input.Socket.Arg = arg;
345 
346         tcpip_callback_with_block(LibTCPSocketCallback, msg, 1);
347 
348         if (WaitForEventSafely(&msg->Event))
349             ret = msg->Output.Socket.NewPcb;
350         else
351             ret = NULL;
352 
353         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
354 
355         return ret;
356     }
357 
358     return NULL;
359 }
360 
361 static
362 void
LibTCPFreeSocketCallback(void * arg)363 LibTCPFreeSocketCallback(void *arg)
364 {
365     struct lwip_callback_msg *msg = arg;
366 
367     ASSERT(msg);
368 
369     /* Calling tcp_close will free it */
370     tcp_close(msg->Input.FreeSocket.pcb);
371 
372     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
373 }
374 
/* Releases a PCB by queueing LibTCPFreeSocketCallback and waiting for it. The
 * request message lives on this function's stack; the result of the wait is
 * not examined. */
void LibTCPFreeSocket(PTCP_PCB pcb)
{
    struct lwip_callback_msg Request;

    Request.Input.FreeSocket.pcb = pcb;
    KeInitializeEvent(&Request.Event, NotificationEvent, FALSE);

    tcpip_callback_with_block(LibTCPFreeSocketCallback, &Request, 1);

    (void)WaitForEventSafely(&Request.Event);
}
386 
387 
388 static
389 void
LibTCPBindCallback(void * arg)390 LibTCPBindCallback(void *arg)
391 {
392     struct lwip_callback_msg *msg = arg;
393     PTCP_PCB pcb = msg->Input.Bind.Connection->SocketContext;
394 
395     ASSERT(msg);
396 
397     if (!msg->Input.Bind.Connection->SocketContext)
398     {
399         msg->Output.Bind.Error = ERR_CLSD;
400         goto done;
401     }
402 
403     /* We're guaranteed that the local address is valid to bind at this point */
404     pcb->so_options |= SOF_REUSEADDR;
405 
406     msg->Output.Bind.Error = tcp_bind(pcb,
407                                       msg->Input.Bind.IpAddress,
408                                       lwip_ntohs(msg->Input.Bind.Port));
409 
410 done:
411     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
412 }
413 
414 err_t
LibTCPBind(PCONNECTION_ENDPOINT Connection,ip4_addr_t * const ipaddr,const u16_t port)415 LibTCPBind(PCONNECTION_ENDPOINT Connection, ip4_addr_t *const ipaddr, const u16_t port)
416 {
417     struct lwip_callback_msg *msg;
418     err_t ret;
419 
420     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
421     if (msg)
422     {
423         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
424         msg->Input.Bind.Connection = Connection;
425         msg->Input.Bind.IpAddress = ipaddr;
426         msg->Input.Bind.Port = port;
427 
428         tcpip_callback_with_block(LibTCPBindCallback, msg, 1);
429 
430         if (WaitForEventSafely(&msg->Event))
431             ret = msg->Output.Bind.Error;
432         else
433             ret = ERR_CLSD;
434 
435         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
436 
437         return ret;
438     }
439 
440     return ERR_MEM;
441 }
442 
443 static
444 void
LibTCPListenCallback(void * arg)445 LibTCPListenCallback(void *arg)
446 {
447     struct lwip_callback_msg *msg = arg;
448 
449     ASSERT(msg);
450 
451     if (!msg->Input.Listen.Connection->SocketContext)
452     {
453         msg->Output.Listen.NewPcb = NULL;
454         goto done;
455     }
456 
457     msg->Output.Listen.NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Input.Listen.Connection->SocketContext, msg->Input.Listen.Backlog);
458 
459     if (msg->Output.Listen.NewPcb)
460     {
461         tcp_accept(msg->Output.Listen.NewPcb, InternalAcceptEventHandler);
462     }
463 
464 done:
465     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
466 }
467 
468 PTCP_PCB
LibTCPListen(PCONNECTION_ENDPOINT Connection,const u8_t backlog)469 LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog)
470 {
471     struct lwip_callback_msg *msg;
472     PTCP_PCB ret;
473 
474     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
475     if (msg)
476     {
477         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
478         msg->Input.Listen.Connection = Connection;
479         msg->Input.Listen.Backlog = backlog;
480 
481         tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
482 
483         if (WaitForEventSafely(&msg->Event))
484             ret = msg->Output.Listen.NewPcb;
485         else
486             ret = NULL;
487 
488         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
489 
490         return ret;
491     }
492 
493     return NULL;
494 }
495 
496 static
497 void
LibTCPSendCallback(void * arg)498 LibTCPSendCallback(void *arg)
499 {
500     struct lwip_callback_msg *msg = arg;
501     PTCP_PCB pcb = msg->Input.Send.Connection->SocketContext;
502     ULONG SendLength;
503     UCHAR SendFlags;
504 
505     ASSERT(msg);
506 
507     if (!msg->Input.Send.Connection->SocketContext)
508     {
509         msg->Output.Send.Error = ERR_CLSD;
510         goto done;
511     }
512 
513     if (msg->Input.Send.Connection->SendShutdown)
514     {
515         msg->Output.Send.Error = ERR_CLSD;
516         goto done;
517     }
518 
519     SendFlags = TCP_WRITE_FLAG_COPY;
520     SendLength = msg->Input.Send.DataLength;
521     if (tcp_sndbuf(pcb) == 0)
522     {
523         /* No buffer space so return pending */
524         msg->Output.Send.Error = ERR_INPROGRESS;
525         goto done;
526     }
527     else if (tcp_sndbuf(pcb) < SendLength)
528     {
529         /* We've got some room so let's send what we can */
530         SendLength = tcp_sndbuf(pcb);
531 
532         /* Don't set the push flag */
533         SendFlags |= TCP_WRITE_FLAG_MORE;
534     }
535 
536     msg->Output.Send.Error = tcp_write(pcb,
537                                        msg->Input.Send.Data,
538                                        SendLength,
539                                        SendFlags);
540     if (msg->Output.Send.Error == ERR_OK)
541     {
542         /* Queued successfully so try to send it */
543         tcp_output((PTCP_PCB)msg->Input.Send.Connection->SocketContext);
544         msg->Output.Send.Information = SendLength;
545     }
546     else if (msg->Output.Send.Error == ERR_MEM)
547     {
548         /* The queue is too long */
549         msg->Output.Send.Error = ERR_INPROGRESS;
550     }
551 
552 done:
553     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
554 }
555 
556 err_t
LibTCPSend(PCONNECTION_ENDPOINT Connection,void * const dataptr,const u16_t len,ULONG * sent,const int safe)557 LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, ULONG *sent, const int safe)
558 {
559     err_t ret;
560     struct lwip_callback_msg *msg;
561 
562     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
563     if (msg)
564     {
565         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
566         msg->Input.Send.Connection = Connection;
567         msg->Input.Send.Data = dataptr;
568         msg->Input.Send.DataLength = len;
569 
570         if (safe)
571             LibTCPSendCallback(msg);
572         else
573             tcpip_callback_with_block(LibTCPSendCallback, msg, 1);
574 
575         if (WaitForEventSafely(&msg->Event))
576             ret = msg->Output.Send.Error;
577         else
578             ret = ERR_CLSD;
579 
580         if (ret == ERR_OK)
581             *sent = msg->Output.Send.Information;
582         else
583             *sent = 0;
584 
585         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
586 
587         return ret;
588     }
589 
590     return ERR_MEM;
591 }
592 
593 static
594 void
LibTCPConnectCallback(void * arg)595 LibTCPConnectCallback(void *arg)
596 {
597     struct lwip_callback_msg *msg = arg;
598     err_t Error;
599 
600     ASSERT(arg);
601 
602     if (!msg->Input.Connect.Connection->SocketContext)
603     {
604         msg->Output.Connect.Error = ERR_CLSD;
605         goto done;
606     }
607 
608     tcp_recv((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalRecvEventHandler);
609     tcp_sent((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalSendEventHandler);
610 
611     Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
612                         msg->Input.Connect.IpAddress, lwip_ntohs(msg->Input.Connect.Port),
613                         InternalConnectEventHandler);
614 
615     msg->Output.Connect.Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
616 
617 done:
618     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
619 }
620 
621 err_t
LibTCPConnect(PCONNECTION_ENDPOINT Connection,ip_addr_t * const ipaddr,const u16_t port)622 LibTCPConnect(PCONNECTION_ENDPOINT Connection, ip_addr_t *const ipaddr, const u16_t port)
623 {
624     struct lwip_callback_msg *msg;
625     err_t ret;
626 
627     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
628     if (msg)
629     {
630         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
631         msg->Input.Connect.Connection = Connection;
632         msg->Input.Connect.IpAddress = ipaddr;
633         msg->Input.Connect.Port = port;
634 
635         tcpip_callback_with_block(LibTCPConnectCallback, msg, 1);
636 
637         if (WaitForEventSafely(&msg->Event))
638         {
639             ret = msg->Output.Connect.Error;
640         }
641         else
642             ret = ERR_CLSD;
643 
644         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
645 
646         return ret;
647     }
648 
649     return ERR_MEM;
650 }
651 
652 static
653 void
LibTCPShutdownCallback(void * arg)654 LibTCPShutdownCallback(void *arg)
655 {
656     struct lwip_callback_msg *msg = arg;
657     PTCP_PCB pcb = msg->Input.Shutdown.Connection->SocketContext;
658 
659     if (!msg->Input.Shutdown.Connection->SocketContext)
660     {
661         msg->Output.Shutdown.Error = ERR_CLSD;
662         goto done;
663     }
664 
665     /* LwIP makes the (questionable) assumption that SHUTDOWN_RDWR is equivalent to tcp_close().
666      * This assumption holds even if the shutdown calls are done separately (even through multiple
667      * WinSock shutdown() calls). This assumption means that lwIP has the right to deallocate our
668      * PCB without telling us if we shutdown TX and RX. To avoid these problems, we'll clear the
669      * socket context if we have called shutdown for TX and RX.
670      */
671     if (msg->Input.Shutdown.shut_rx != msg->Input.Shutdown.shut_tx) {
672         if (msg->Input.Shutdown.shut_rx) {
673             msg->Output.Shutdown.Error = tcp_shutdown(pcb, TRUE, FALSE);
674         }
675         if (msg->Input.Shutdown.shut_tx) {
676             msg->Output.Shutdown.Error = tcp_shutdown(pcb, FALSE, TRUE);
677         }
678     }
679     else if (msg->Input.Shutdown.shut_rx) {
680         /* We received both RX and TX requests, which seems to mean closing connection from TDI.
681          * So call tcp_close, otherwise we risk to be put in TCP_WAIT_* states, which makes further
682          * attempts to close the socket to fail in this state.
683          */
684         msg->Output.Shutdown.Error = tcp_close(pcb);
685     }
686     else {
687         /* This case shouldn't happen */
688         DbgPrint("Requested socket shutdown(0, 0) !\n");
689     }
690 
691     if (!msg->Output.Shutdown.Error)
692     {
693         if (msg->Input.Shutdown.shut_rx)
694         {
695             msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
696             msg->Input.Shutdown.Connection->ReceiveShutdownStatus = STATUS_FILE_CLOSED;
697         }
698 
699         if (msg->Input.Shutdown.shut_tx)
700             msg->Input.Shutdown.Connection->SendShutdown = TRUE;
701 
702         if (msg->Input.Shutdown.Connection->ReceiveShutdown &&
703             msg->Input.Shutdown.Connection->SendShutdown)
704         {
705             /* The PCB is not ours anymore */
706             msg->Input.Shutdown.Connection->SocketContext = NULL;
707             tcp_arg(pcb, NULL);
708             TCPFinEventHandler(msg->Input.Shutdown.Connection, ERR_CLSD);
709         }
710     }
711 
712 done:
713     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
714 }
715 
716 err_t
LibTCPShutdown(PCONNECTION_ENDPOINT Connection,const int shut_rx,const int shut_tx)717 LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shut_tx)
718 {
719     struct lwip_callback_msg *msg;
720     err_t ret;
721 
722     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
723     if (msg)
724     {
725         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
726 
727         msg->Input.Shutdown.Connection = Connection;
728         msg->Input.Shutdown.shut_rx = shut_rx;
729         msg->Input.Shutdown.shut_tx = shut_tx;
730 
731         tcpip_callback_with_block(LibTCPShutdownCallback, msg, 1);
732 
733         if (WaitForEventSafely(&msg->Event))
734             ret = msg->Output.Shutdown.Error;
735         else
736             ret = ERR_CLSD;
737 
738         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
739 
740         return ret;
741     }
742 
743     return ERR_MEM;
744 }
745 
746 static
747 void
LibTCPCloseCallback(void * arg)748 LibTCPCloseCallback(void *arg)
749 {
750     struct lwip_callback_msg *msg = arg;
751     PTCP_PCB pcb = msg->Input.Close.Connection->SocketContext;
752 
753     /* Empty the queue even if we're already "closed" */
754     LibTCPEmptyQueue(msg->Input.Close.Connection);
755 
756     /* Check if we've already been closed */
757     if (msg->Input.Close.Connection->Closing)
758     {
759         msg->Output.Close.Error = ERR_OK;
760         goto done;
761     }
762 
763     /* Enter "closing" mode if we're doing a normal close */
764     if (msg->Input.Close.Callback)
765         msg->Input.Close.Connection->Closing = TRUE;
766 
767     /* Check if the PCB was already "closed" but the client doesn't know it yet */
768     if (!msg->Input.Close.Connection->SocketContext)
769     {
770         msg->Output.Close.Error = ERR_OK;
771         goto done;
772     }
773 
774     /* Clear the PCB pointer and stop callbacks */
775     msg->Input.Close.Connection->SocketContext = NULL;
776     tcp_arg(pcb, NULL);
777 
778     /* This may generate additional callbacks but we don't care,
779      * because they're too inconsistent to rely on */
780     msg->Output.Close.Error = tcp_close(pcb);
781 
782     if (msg->Output.Close.Error)
783     {
784         /* Restore the PCB pointer */
785         msg->Input.Close.Connection->SocketContext = pcb;
786         msg->Input.Close.Connection->Closing = FALSE;
787     }
788     else if (msg->Input.Close.Callback)
789     {
790         TCPFinEventHandler(msg->Input.Close.Connection, ERR_CLSD);
791     }
792 
793 done:
794     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
795 }
796 
797 err_t
LibTCPClose(PCONNECTION_ENDPOINT Connection,const int safe,const int callback)798 LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe, const int callback)
799 {
800     err_t ret;
801     struct lwip_callback_msg *msg;
802 
803     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
804     if (msg)
805     {
806         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
807 
808         msg->Input.Close.Connection = Connection;
809         msg->Input.Close.Callback = callback;
810 
811         if (safe)
812             LibTCPCloseCallback(msg);
813         else
814             tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
815 
816         if (WaitForEventSafely(&msg->Event))
817             ret = msg->Output.Close.Error;
818         else
819             ret = ERR_CLSD;
820 
821         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
822 
823         return ret;
824     }
825 
826     return ERR_MEM;
827 }
828 
829 void
LibTCPAccept(PTCP_PCB pcb,struct tcp_pcb * listen_pcb,void * arg)830 LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
831 {
832     ASSERT(arg);
833 
834     tcp_arg(pcb, NULL);
835     tcp_recv(pcb, InternalRecvEventHandler);
836     tcp_sent(pcb, InternalSendEventHandler);
837     tcp_err(pcb, InternalErrorEventHandler);
838     tcp_arg(pcb, arg);
839 
840     tcp_accepted(listen_pcb);
841 }
842 
843 err_t
LibTCPGetHostName(PTCP_PCB pcb,ip_addr_t * const ipaddr,u16_t * const port)844 LibTCPGetHostName(PTCP_PCB pcb, ip_addr_t *const ipaddr, u16_t *const port)
845 {
846     if (!pcb)
847         return ERR_CLSD;
848 
849     *ipaddr = pcb->local_ip;
850     *port = pcb->local_port;
851 
852     return ERR_OK;
853 }
854 
855 err_t
LibTCPGetPeerName(PTCP_PCB pcb,ip_addr_t * const ipaddr,u16_t * const port)856 LibTCPGetPeerName(PTCP_PCB pcb, ip_addr_t * const ipaddr, u16_t * const port)
857 {
858     if (!pcb)
859         return ERR_CLSD;
860 
861     *ipaddr = pcb->remote_ip;
862     *port = pcb->remote_port;
863 
864     return ERR_OK;
865 }
866 
867 void
LibTCPSetNoDelay(PTCP_PCB pcb,BOOLEAN Set)868 LibTCPSetNoDelay(
869     PTCP_PCB pcb,
870     BOOLEAN Set)
871 {
872     if (Set)
873         pcb->flags |= TF_NODELAY;
874     else
875         pcb->flags &= ~TF_NODELAY;
876 }
877 
878 void
LibTCPGetSocketStatus(PTCP_PCB pcb,PULONG State)879 LibTCPGetSocketStatus(
880     PTCP_PCB pcb,
881     PULONG State)
882 {
883     /* Translate state from enum tcp_state -> MIB_TCP_STATE */
884     *State = pcb->state + 1;
885 }
886