xref: /reactos/ntoskrnl/lpc/send.c (revision b3c55b9e)
1 /*
2  * PROJECT:         ReactOS Kernel
3  * LICENSE:         GPL - See COPYING in the top level directory
4  * FILE:            ntoskrnl/lpc/send.c
5  * PURPOSE:         Local Procedure Call: Sending (Requests)
6  * PROGRAMMERS:     Alex Ionescu (alex.ionescu@reactos.org)
7  */
8 
9 /* INCLUDES ******************************************************************/
10 
11 #include <ntoskrnl.h>
12 #define NDEBUG
13 #include <debug.h>
14 
15 /* PUBLIC FUNCTIONS **********************************************************/
16 
17 /*
18  * @implemented
19  */
20 NTSTATUS
21 NTAPI
LpcRequestPort(IN PVOID PortObject,IN PPORT_MESSAGE LpcMessage)22 LpcRequestPort(IN PVOID PortObject,
23                IN PPORT_MESSAGE LpcMessage)
24 {
25     PLPCP_PORT_OBJECT Port = PortObject, QueuePort, ConnectionPort = NULL;
26     ULONG MessageType;
27     PLPCP_MESSAGE Message;
28     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
29     PETHREAD Thread = PsGetCurrentThread();
30 
31     PAGED_CODE();
32 
33     LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", Port, LpcMessage);
34 
35     /* Check if this is a non-datagram message */
36     if (LpcMessage->u2.s2.Type)
37     {
38         /* Get the message type */
39         MessageType = LpcpGetMessageType(LpcMessage);
40 
41         /* Validate it */
42         if ((MessageType < LPC_DATAGRAM) || (MessageType > LPC_CLIENT_DIED))
43         {
44             /* Fail */
45             return STATUS_INVALID_PARAMETER;
46         }
47 
48         /* Mark this as a kernel-mode message only if we really came from it */
49         if ((PreviousMode == KernelMode) &&
50             (LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE))
51         {
52             /* We did, this is a kernel mode message */
53             MessageType |= LPC_KERNELMODE_MESSAGE;
54         }
55     }
56     else
57     {
58         /* This is a datagram */
59         MessageType = LPC_DATAGRAM;
60     }
61 
62     /* Can't have data information on this type of call */
63     if (LpcMessage->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
64 
65     /* Validate the message length */
66     if (((ULONG)LpcMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
67         ((ULONG)LpcMessage->u1.s1.TotalLength <= (ULONG)LpcMessage->u1.s1.DataLength))
68     {
69         /* Fail */
70         return STATUS_PORT_MESSAGE_TOO_LONG;
71     }
72 
73     /* Allocate a new message */
74     Message = LpcpAllocateFromPortZone();
75     if (!Message) return STATUS_NO_MEMORY;
76 
77     /* Clear the context */
78     Message->RepliedToThread = NULL;
79     Message->PortContext = NULL;
80 
81     /* Copy the message */
82     LpcpMoveMessage(&Message->Request,
83                     LpcMessage,
84                     LpcMessage + 1,
85                     MessageType,
86                     &Thread->Cid);
87 
88     /* Acquire the LPC lock */
89     KeAcquireGuardedMutex(&LpcpLock);
90 
91     /* Check if this is anything but a connection port */
92     if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
93     {
94         /* The queue port is the connected port */
95         QueuePort = Port->ConnectedPort;
96         if (QueuePort)
97         {
98             /* Check if this is a client port */
99             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
100             {
101                 /* Then copy the context */
102                 Message->PortContext = QueuePort->PortContext;
103                 ConnectionPort = QueuePort = Port->ConnectionPort;
104                 if (!ConnectionPort)
105                 {
106                     /* Fail */
107                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
108                     return STATUS_PORT_DISCONNECTED;
109                 }
110             }
111             else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
112             {
113                 /* Any other kind of port, use the connection port */
114                 ConnectionPort = QueuePort = Port->ConnectionPort;
115                 if (!ConnectionPort)
116                 {
117                     /* Fail */
118                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
119                     return STATUS_PORT_DISCONNECTED;
120                 }
121             }
122 
123             /* If we have a connection port, reference it */
124             if (ConnectionPort) ObReferenceObject(ConnectionPort);
125         }
126     }
127     else
128     {
129         /* For connection ports, use the port itself */
130         QueuePort = PortObject;
131     }
132 
133     /* Make sure we have a port */
134     if (QueuePort)
135     {
136         /* Generate the Message ID and set it */
137         Message->Request.MessageId = LpcpNextMessageId++;
138         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
139         Message->Request.CallbackId = 0;
140 
141         /* No Message ID for the thread */
142         Thread->LpcReplyMessageId = 0;
143 
144         /* Insert the message in our chain */
145         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
146 
147         /* Release the lock and the semaphore */
148         KeEnterCriticalRegion();
149         KeReleaseGuardedMutex(&LpcpLock);
150         LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
151 
152         /* If this is a waitable port, wake it up */
153         if (QueuePort->Flags & LPCP_WAITABLE_PORT)
154         {
155             /* Wake it */
156             KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
157         }
158 
159         KeLeaveCriticalRegion();
160 
161         /* We're done */
162         if (ConnectionPort) ObDereferenceObject(ConnectionPort);
163         LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
164         return STATUS_SUCCESS;
165     }
166 
167     /* If we got here, then free the message and fail */
168     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
169     if (ConnectionPort) ObDereferenceObject(ConnectionPort);
170     return STATUS_PORT_DISCONNECTED;
171 }
172 
173 /*
174 * @implemented
175 */
176 NTSTATUS
177 NTAPI
LpcRequestWaitReplyPort(IN PVOID PortObject,IN PPORT_MESSAGE LpcRequest,OUT PPORT_MESSAGE LpcReply)178 LpcRequestWaitReplyPort(IN PVOID PortObject,
179                         IN PPORT_MESSAGE LpcRequest,
180                         OUT PPORT_MESSAGE LpcReply)
181 {
182     NTSTATUS Status = STATUS_SUCCESS;
183     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
184     PETHREAD Thread = PsGetCurrentThread();
185     PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)PortObject;
186     PLPCP_PORT_OBJECT QueuePort, ReplyPort, ConnectionPort = NULL;
187     USHORT MessageType;
188     PLPCP_MESSAGE Message;
189     BOOLEAN Callback = FALSE;
190     PKSEMAPHORE Semaphore;
191 
192     PAGED_CODE();
193 
194     LPCTRACE(LPC_SEND_DEBUG,
195              "Port: %p. Messages: %p/%p. Type: %lx\n",
196              Port,
197              LpcRequest,
198              LpcReply,
199              LpcpGetMessageType(LpcRequest));
200 
201     /* Check if the thread is dying */
202     if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
203 
204     /* Check if this is an LPC Request */
205     MessageType = LpcpGetMessageType(LpcRequest);
206     switch (MessageType)
207     {
208         /* No type, assume LPC request */
209         case 0:
210             MessageType = LPC_REQUEST;
211             break;
212 
213         /* LPC request callback */
214         case LPC_REQUEST:
215             Callback = TRUE;
216             break;
217 
218         /* Anything else, nothing to do */
219         case LPC_CLIENT_DIED:
220         case LPC_PORT_CLOSED:
221         case LPC_EXCEPTION:
222         case LPC_DEBUG_EVENT:
223         case LPC_ERROR_EVENT:
224             break;
225 
226         /* Invalid message type */
227         default:
228             return STATUS_INVALID_PARAMETER;
229     }
230 
231     /* Set the request type */
232     LpcRequest->u2.s2.Type = MessageType;
233 
234     /* Validate the message length */
235     if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
236         ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
237     {
238         /* Fail */
239         return STATUS_PORT_MESSAGE_TOO_LONG;
240     }
241 
242     /* Allocate a message from the port zone */
243     Message = LpcpAllocateFromPortZone();
244     if (!Message)
245     {
246         /* Fail if we couldn't allocate a message */
247         return STATUS_NO_MEMORY;
248     }
249 
250     /* Check if this is a callback */
251     if (Callback)
252     {
253         /* FIXME: TODO */
254         Semaphore = NULL; // we'd use the Thread Semaphore here
255         ASSERT(FALSE);
256         return STATUS_NOT_IMPLEMENTED;
257     }
258     else
259     {
260         /* No callback, just copy the message */
261         LpcpMoveMessage(&Message->Request,
262                         LpcRequest,
263                         LpcRequest + 1,
264                         0,
265                         &Thread->Cid);
266 
267         /* Acquire the LPC lock */
268         KeAcquireGuardedMutex(&LpcpLock);
269 
270         /* Right now clear the port context */
271         Message->PortContext = NULL;
272 
273         /* Check if this is a not connection port */
274         if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
275         {
276             /* We want the connected port */
277             QueuePort = Port->ConnectedPort;
278             if (!QueuePort)
279             {
280                 /* We have no connected port, fail */
281                 LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
282                 return STATUS_PORT_DISCONNECTED;
283             }
284 
285             /* This will be the rundown port */
286             ReplyPort = QueuePort;
287 
288             /* Check if this is a communication port */
289             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
290             {
291                 /* Copy the port context and use the connection port */
292                 Message->PortContext = QueuePort->PortContext;
293                 ConnectionPort = QueuePort = Port->ConnectionPort;
294                 if (!ConnectionPort)
295                 {
296                     /* Fail */
297                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
298                     return STATUS_PORT_DISCONNECTED;
299                 }
300             }
301             else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
302                       LPCP_COMMUNICATION_PORT)
303             {
304                 /* Use the connection port for anything but communication ports */
305                 ConnectionPort = QueuePort = Port->ConnectionPort;
306                 if (!ConnectionPort)
307                 {
308                     /* Fail */
309                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
310                     return STATUS_PORT_DISCONNECTED;
311                 }
312             }
313 
314             /* Reference the connection port if it exists */
315             if (ConnectionPort) ObReferenceObject(ConnectionPort);
316         }
317         else
318         {
319             /* Otherwise, for a connection port, use the same port object */
320             QueuePort = ReplyPort = Port;
321         }
322 
323         /* No reply thread */
324         Message->RepliedToThread = NULL;
325         Message->SenderPort = Port;
326 
327         /* Generate the Message ID and set it */
328         Message->Request.MessageId = LpcpNextMessageId++;
329         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
330         Message->Request.CallbackId = 0;
331 
332         /* Set the message ID for our thread now */
333         Thread->LpcReplyMessageId = Message->Request.MessageId;
334         Thread->LpcReplyMessage = NULL;
335 
336         /* Insert the message in our chain */
337         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
338         InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
339         LpcpSetPortToThread(Thread, Port);
340 
341         /* Release the lock and get the semaphore we'll use later */
342         KeEnterCriticalRegion();
343         KeReleaseGuardedMutex(&LpcpLock);
344         Semaphore = QueuePort->MsgQueue.Semaphore;
345 
346         /* If this is a waitable port, wake it up */
347         if (QueuePort->Flags & LPCP_WAITABLE_PORT)
348         {
349             /* Wake it */
350             KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
351         }
352     }
353 
354     /* Now release the semaphore */
355     LpcpCompleteWait(Semaphore);
356     KeLeaveCriticalRegion();
357 
358     /* And let's wait for the reply */
359     LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);
360 
361     /* Acquire the LPC lock */
362     KeAcquireGuardedMutex(&LpcpLock);
363 
364     /* Get the LPC Message and clear our thread's reply data */
365     Message = LpcpGetMessageFromThread(Thread);
366     Thread->LpcReplyMessage = NULL;
367     Thread->LpcReplyMessageId = 0;
368 
369     /* Check if we have anything on the reply chain*/
370     if (!IsListEmpty(&Thread->LpcReplyChain))
371     {
372         /* Remove this thread and reinitialize the list */
373         RemoveEntryList(&Thread->LpcReplyChain);
374         InitializeListHead(&Thread->LpcReplyChain);
375     }
376 
377     /* Release the lock */
378     KeReleaseGuardedMutex(&LpcpLock);
379 
380     /* Check if we got a reply */
381     if (Status == STATUS_SUCCESS)
382     {
383         /* Check if we have a valid message */
384         if (Message)
385         {
386             LPCTRACE(LPC_SEND_DEBUG,
387                      "Reply Messages: %p/%p\n",
388                      &Message->Request,
389                      (&Message->Request) + 1);
390 
391             /* Move the message */
392             LpcpMoveMessage(LpcReply,
393                             &Message->Request,
394                             (&Message->Request) + 1,
395                             0,
396                             NULL);
397 
398             /* Acquire the lock */
399             KeAcquireGuardedMutex(&LpcpLock);
400 
401             /* Check if we replied to a thread */
402             if (Message->RepliedToThread)
403             {
404                 /* Dereference */
405                 ObDereferenceObject(Message->RepliedToThread);
406                 Message->RepliedToThread = NULL;
407             }
408 
409             /* Free the message */
410             LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
411         }
412         else
413         {
414             /* We don't have a reply */
415             Status = STATUS_LPC_REPLY_LOST;
416         }
417     }
418     else
419     {
420         /* The wait failed, free the message */
421         if (Message) LpcpFreeToPortZone(Message, 0);
422     }
423 
424     /* All done */
425     LPCTRACE(LPC_SEND_DEBUG,
426              "Port: %p. Status: %d\n",
427              Port,
428              Status);
429 
430     /* Dereference the connection port */
431     if (ConnectionPort) ObDereferenceObject(ConnectionPort);
432     return Status;
433 }
434 
435 /*
436  * @implemented
437  */
438 NTSTATUS
439 NTAPI
NtRequestPort(IN HANDLE PortHandle,IN PPORT_MESSAGE LpcRequest)440 NtRequestPort(IN HANDLE PortHandle,
441               IN PPORT_MESSAGE LpcRequest)
442 {
443     NTSTATUS Status;
444     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
445     PETHREAD Thread = PsGetCurrentThread();
446     PORT_MESSAGE CapturedLpcRequest;
447     PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
448     ULONG MessageType;
449     PLPCP_MESSAGE Message;
450 
451     PAGED_CODE();
452 
453     /* Check if the call comes from user mode */
454     if (PreviousMode != KernelMode)
455     {
456         _SEH2_TRY
457         {
458             /* Probe and capture the LpcRequest */
459             ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
460             CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
461         }
462         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
463         {
464             _SEH2_YIELD(return _SEH2_GetExceptionCode());
465         }
466         _SEH2_END;
467     }
468     else
469     {
470         /* Access the LpcRequest directly */
471         CapturedLpcRequest = *LpcRequest;
472     }
473 
474     LPCTRACE(LPC_SEND_DEBUG,
475              "Handle: %p. Message: %p. Type: %lx\n",
476              PortHandle,
477              LpcRequest,
478              LpcpGetMessageType(&CapturedLpcRequest));
479 
480     /* Get the message type */
481     MessageType = CapturedLpcRequest.u2.s2.Type | LPC_DATAGRAM;
482 
483     /* Can't have data information on this type of call */
484     if (CapturedLpcRequest.u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
485 
486     /* Validate the length */
487     if (((ULONG)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
488          (ULONG)CapturedLpcRequest.u1.s1.TotalLength)
489     {
490         /* Fail */
491         return STATUS_INVALID_PARAMETER;
492     }
493 
494     /* Reference the object */
495     Status = ObReferenceObjectByHandle(PortHandle,
496                                        0,
497                                        LpcPortObjectType,
498                                        PreviousMode,
499                                        (PVOID*)&Port,
500                                        NULL);
501     if (!NT_SUCCESS(Status)) return Status;
502 
503     /* Validate the message length */
504     if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
505         ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
506     {
507         /* Fail */
508         ObDereferenceObject(Port);
509         return STATUS_PORT_MESSAGE_TOO_LONG;
510     }
511 
512     /* Allocate a message from the port zone */
513     Message = LpcpAllocateFromPortZone();
514     if (!Message)
515     {
516         /* Fail if we couldn't allocate a message */
517         ObDereferenceObject(Port);
518         return STATUS_NO_MEMORY;
519     }
520 
521     /* No callback, just copy the message */
522     _SEH2_TRY
523     {
524         /* Copy it */
525         LpcpMoveMessage(&Message->Request,
526                         &CapturedLpcRequest,
527                         LpcRequest + 1,
528                         MessageType,
529                         &Thread->Cid);
530     }
531     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
532     {
533         /* Cleanup and return the exception code */
534         LpcpFreeToPortZone(Message, 0);
535         ObDereferenceObject(Port);
536         _SEH2_YIELD(return _SEH2_GetExceptionCode());
537     }
538     _SEH2_END;
539 
540     /* Acquire the LPC lock */
541     KeAcquireGuardedMutex(&LpcpLock);
542 
543     /* Right now clear the port context */
544     Message->PortContext = NULL;
545 
546     /* Check if this is a not connection port */
547     if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
548     {
549         /* We want the connected port */
550         QueuePort = Port->ConnectedPort;
551         if (!QueuePort)
552         {
553             /* We have no connected port, fail */
554             LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
555             ObDereferenceObject(Port);
556             return STATUS_PORT_DISCONNECTED;
557         }
558 
559         /* Check if this is a communication port */
560         if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
561         {
562             /* Copy the port context and use the connection port */
563             Message->PortContext = QueuePort->PortContext;
564             ConnectionPort = QueuePort = Port->ConnectionPort;
565             if (!ConnectionPort)
566             {
567                 /* Fail */
568                 LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
569                 ObDereferenceObject(Port);
570                 return STATUS_PORT_DISCONNECTED;
571             }
572         }
573         else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
574         {
575             /* Use the connection port for anything but communication ports */
576             ConnectionPort = QueuePort = Port->ConnectionPort;
577             if (!ConnectionPort)
578             {
579                 /* Fail */
580                 LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
581                 ObDereferenceObject(Port);
582                 return STATUS_PORT_DISCONNECTED;
583             }
584         }
585 
586         /* Reference the connection port if it exists */
587         if (ConnectionPort) ObReferenceObject(ConnectionPort);
588     }
589     else
590     {
591         /* Otherwise, for a connection port, use the same port object */
592         QueuePort = Port;
593     }
594 
595     /* Reference QueuePort if we have it */
596     if (QueuePort && ObReferenceObjectSafe(QueuePort))
597     {
598         /* Set sender's port */
599         Message->SenderPort = Port;
600 
601         /* Generate the Message ID and set it */
602         Message->Request.MessageId = LpcpNextMessageId++;
603         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
604         Message->Request.CallbackId = 0;
605 
606         /* No Message ID for the thread */
607         Thread->LpcReplyMessageId = 0;
608 
609         /* Insert the message in our chain */
610         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
611 
612         /* Release the lock and the semaphore */
613         KeEnterCriticalRegion();
614         KeReleaseGuardedMutex(&LpcpLock);
615         LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
616 
617         /* If this is a waitable port, wake it up */
618         if (QueuePort->Flags & LPCP_WAITABLE_PORT)
619         {
620             /* Wake it */
621             KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
622         }
623 
624         KeLeaveCriticalRegion();
625 
626         /* Dereference objects */
627         if (ConnectionPort) ObDereferenceObject(ConnectionPort);
628         ObDereferenceObject(QueuePort);
629         ObDereferenceObject(Port);
630         LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
631         return STATUS_SUCCESS;
632     }
633 
634     Status = STATUS_PORT_DISCONNECTED;
635 
636     /* All done with a failure*/
637     LPCTRACE(LPC_SEND_DEBUG,
638              "Port: %p. Status: %d\n",
639              Port,
640              Status);
641 
642     /* The wait failed, free the message */
643     if (Message) LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
644 
645     ObDereferenceObject(Port);
646     if (ConnectionPort) ObDereferenceObject(ConnectionPort);
647     return Status;
648 }
649 
650 NTSTATUS
651 NTAPI
LpcpVerifyMessageDataInfo(_In_ PPORT_MESSAGE Message,_Out_ PULONG NumberOfDataEntries)652 LpcpVerifyMessageDataInfo(
653     _In_ PPORT_MESSAGE Message,
654     _Out_ PULONG NumberOfDataEntries)
655 {
656     PLPCP_DATA_INFO DataInfo;
657     PUCHAR EndOfEntries;
658 
659     /* Check if we have no data info at all */
660     if (Message->u2.s2.DataInfoOffset == 0)
661     {
662         *NumberOfDataEntries = 0;
663         return STATUS_SUCCESS;
664     }
665 
666     /* Make sure the data info structure is within the message */
667     if (((ULONG)Message->u1.s1.TotalLength <
668             sizeof(PORT_MESSAGE) + sizeof(LPCP_DATA_INFO)) ||
669         ((ULONG)Message->u2.s2.DataInfoOffset < sizeof(PORT_MESSAGE)) ||
670         ((ULONG)Message->u2.s2.DataInfoOffset >
671             ((ULONG)Message->u1.s1.TotalLength - sizeof(LPCP_DATA_INFO))))
672     {
673         return STATUS_INVALID_PARAMETER;
674     }
675 
676     /* Get a pointer to the data info */
677     DataInfo = LpcpGetDataInfoFromMessage(Message);
678 
679     /* Make sure the full data info with all entries is within the message */
680     EndOfEntries = (PUCHAR)&DataInfo->Entries[DataInfo->NumberOfEntries];
681     if ((EndOfEntries > ((PUCHAR)Message + (ULONG)Message->u1.s1.TotalLength)) ||
682         (EndOfEntries < (PUCHAR)Message))
683     {
684         return STATUS_INVALID_PARAMETER;
685     }
686 
687     *NumberOfDataEntries = DataInfo->NumberOfEntries;
688     return STATUS_SUCCESS;
689 }
690 
691 /*
692  * @implemented
693  */
694 NTSTATUS
695 NTAPI
NtRequestWaitReplyPort(IN HANDLE PortHandle,IN PPORT_MESSAGE LpcRequest,IN OUT PPORT_MESSAGE LpcReply)696 NtRequestWaitReplyPort(IN HANDLE PortHandle,
697                        IN PPORT_MESSAGE LpcRequest,
698                        IN OUT PPORT_MESSAGE LpcReply)
699 {
700     NTSTATUS Status;
701     PORT_MESSAGE CapturedLpcRequest;
702     ULONG NumberOfDataEntries;
703     PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
704     PLPCP_MESSAGE Message;
705     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
706     PETHREAD Thread = PsGetCurrentThread();
707     BOOLEAN Callback;
708     PKSEMAPHORE Semaphore;
709     ULONG MessageType;
710     PLPCP_DATA_INFO DataInfo;
711 
712     PAGED_CODE();
713 
714     /* Check if the thread is dying */
715     if (Thread->LpcExitThreadCalled)
716         return STATUS_THREAD_IS_TERMINATING;
717 
718     /* Check for user mode access */
719     if (PreviousMode != KernelMode)
720     {
721         _SEH2_TRY
722         {
723             /* Probe and capture the LpcRequest */
724             ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
725             CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
726 
727             /* Probe the reply message for write */
728             ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));
729 
730             /* Make sure the data entries in the request message are valid */
731             Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
732             if (!NT_SUCCESS(Status))
733             {
734                 DPRINT1("LpcpVerifyMessageDataInfo failed\n");
735                 _SEH2_YIELD(return Status);
736             }
737         }
738         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
739         {
740             DPRINT1("Got exception\n");
741             _SEH2_YIELD(return _SEH2_GetExceptionCode());
742         }
743         _SEH2_END;
744     }
745     else
746     {
747         CapturedLpcRequest = *LpcRequest;
748         Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
749         if (!NT_SUCCESS(Status))
750         {
751             DPRINT1("LpcpVerifyMessageDataInfo failed\n");
752             return Status;
753         }
754     }
755 
756     LPCTRACE(LPC_SEND_DEBUG,
757              "Handle: %p. Messages: %p/%p. Type: %lx\n",
758              PortHandle,
759              LpcRequest,
760              LpcReply,
761              LpcpGetMessageType(&CapturedLpcRequest));
762 
763     /* This flag is undocumented. Remove it before continuing */
764     CapturedLpcRequest.u2.s2.Type &= ~0x4000;
765 
766     /* Check if this is an LPC Request */
767     if (LpcpGetMessageType(&CapturedLpcRequest) == LPC_REQUEST)
768     {
769         /* Then it's a callback */
770         Callback = TRUE;
771     }
772     else if (LpcpGetMessageType(&CapturedLpcRequest))
773     {
774         /* This is a not kernel-mode message */
775         DPRINT1("Not a kernel-mode message!\n");
776         return STATUS_INVALID_PARAMETER;
777     }
778     else
779     {
780         /* This is a kernel-mode message without a callback */
781         CapturedLpcRequest.u2.s2.Type |= LPC_REQUEST;
782         Callback = FALSE;
783     }
784 
785     /* Get the message type */
786     MessageType = CapturedLpcRequest.u2.s2.Type;
787 
788     /* Due to the above probe, we know that TotalLength is positive */
789     ASSERT(CapturedLpcRequest.u1.s1.TotalLength >= 0);
790 
791     /* Validate the length */
792     if ((((ULONG)(USHORT)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
793          (ULONG)CapturedLpcRequest.u1.s1.TotalLength))
794     {
795         /* Fail */
796         DPRINT1("Invalid message length: %u, %u\n",
797                 CapturedLpcRequest.u1.s1.DataLength,
798                 CapturedLpcRequest.u1.s1.TotalLength);
799         return STATUS_INVALID_PARAMETER;
800     }
801 
802     /* Reference the object */
803     Status = ObReferenceObjectByHandle(PortHandle,
804                                        0,
805                                        LpcPortObjectType,
806                                        PreviousMode,
807                                        (PVOID*)&Port,
808                                        NULL);
809     if (!NT_SUCCESS(Status)) return Status;
810 
811     /* Validate the message length */
812     if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
813         ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
814     {
815         /* Fail */
816         DPRINT1("Invalid message length: %u, %u\n",
817                 CapturedLpcRequest.u1.s1.DataLength,
818                 CapturedLpcRequest.u1.s1.TotalLength);
819         ObDereferenceObject(Port);
820         return STATUS_PORT_MESSAGE_TOO_LONG;
821     }
822 
823     /* Allocate a message from the port zone */
824     Message = LpcpAllocateFromPortZone();
825     if (!Message)
826     {
827         /* Fail if we couldn't allocate a message */
828         DPRINT1("Failed to allocate a message!\n");
829         ObDereferenceObject(Port);
830         return STATUS_NO_MEMORY;
831     }
832 
833     /* Check if this is a callback */
834     if (Callback)
835     {
836         /* FIXME: TODO */
837         Semaphore = NULL; // we'd use the Thread Semaphore here
838         ASSERT(FALSE);
839     }
840     else
841     {
842         /* No callback, just copy the message */
843         _SEH2_TRY
844         {
845             /* Check if we have data info entries */
846             if (LpcRequest->u2.s2.DataInfoOffset != 0)
847             {
848                 /* Get the data info and check if the number of entries matches
849                    what we expect */
850                 DataInfo = LpcpGetDataInfoFromMessage(LpcRequest);
851                 if (DataInfo->NumberOfEntries != NumberOfDataEntries)
852                 {
853                     LpcpFreeToPortZone(Message, 0);
854                     ObDereferenceObject(Port);
855                     DPRINT1("NumberOfEntries has changed: %u, %u\n",
856                             DataInfo->NumberOfEntries, NumberOfDataEntries);
857                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
858                 }
859             }
860 
861             /* Copy it */
862             LpcpMoveMessage(&Message->Request,
863                             &CapturedLpcRequest,
864                             LpcRequest + 1,
865                             MessageType,
866                             &Thread->Cid);
867         }
868         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
869         {
870             /* Cleanup and return the exception code */
871             DPRINT1("Got exception!\n");
872             LpcpFreeToPortZone(Message, 0);
873             ObDereferenceObject(Port);
874             _SEH2_YIELD(return _SEH2_GetExceptionCode());
875         }
876         _SEH2_END;
877 
878         /* Acquire the LPC lock */
879         KeAcquireGuardedMutex(&LpcpLock);
880 
881         /* Right now clear the port context */
882         Message->PortContext = NULL;
883 
884         /* Check if this is a not connection port */
885         if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
886         {
887             /* We want the connected port */
888             QueuePort = Port->ConnectedPort;
889             if (!QueuePort)
890             {
891                 /* We have no connected port, fail */
892                 DPRINT1("No connected port\n");
893                 LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
894                 ObDereferenceObject(Port);
895                 return STATUS_PORT_DISCONNECTED;
896             }
897 
898             /* This will be the rundown port */
899             ReplyPort = QueuePort;
900 
901             /* Check if this is a client port */
902             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
903             {
904                 /* Copy the port context */
905                 Message->PortContext = QueuePort->PortContext;
906             }
907 
908             if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
909             {
910                 /* Use the connection port for anything but communication ports */
911                 ConnectionPort = QueuePort = Port->ConnectionPort;
912                 if (!ConnectionPort)
913                 {
914                     /* Fail */
915                     DPRINT1("No connection port\n");
916                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
917                     ObDereferenceObject(Port);
918                     return STATUS_PORT_DISCONNECTED;
919                 }
920             }
921 
922             /* Reference the connection port if it exists */
923             if (ConnectionPort) ObReferenceObject(ConnectionPort);
924         }
925         else
926         {
927             /* Otherwise, for a connection port, use the same port object */
928             QueuePort = ReplyPort = Port;
929         }
930 
931         /* No reply thread */
932         Message->RepliedToThread = NULL;
933         Message->SenderPort = Port;
934 
935         /* Generate the Message ID and set it */
936         Message->Request.MessageId = LpcpNextMessageId++;
937         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
938         Message->Request.CallbackId = 0;
939 
940         /* Set the message ID for our thread now */
941         Thread->LpcReplyMessageId = Message->Request.MessageId;
942         Thread->LpcReplyMessage = NULL;
943 
944         /* Insert the message in our chain */
945         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
946         InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
947         LpcpSetPortToThread(Thread, Port);
948 
949         /* Release the lock and get the semaphore we'll use later */
950         KeEnterCriticalRegion();
951         KeReleaseGuardedMutex(&LpcpLock);
952         Semaphore = QueuePort->MsgQueue.Semaphore;
953 
954         /* If this is a waitable port, wake it up */
955         if (QueuePort->Flags & LPCP_WAITABLE_PORT)
956         {
957             /* Wake it */
958             KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
959         }
960     }
961 
962     /* Now release the semaphore */
963     LpcpCompleteWait(Semaphore);
964     KeLeaveCriticalRegion();
965 
966     /* And let's wait for the reply */
967     LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);
968 
969     /* Acquire the LPC lock */
970     KeAcquireGuardedMutex(&LpcpLock);
971 
972     /* Get the LPC Message and clear our thread's reply data */
973     Message = LpcpGetMessageFromThread(Thread);
974     Thread->LpcReplyMessage = NULL;
975     Thread->LpcReplyMessageId = 0;
976 
977     /* Check if we have anything on the reply chain*/
978     if (!IsListEmpty(&Thread->LpcReplyChain))
979     {
980         /* Remove this thread and reinitialize the list */
981         RemoveEntryList(&Thread->LpcReplyChain);
982         InitializeListHead(&Thread->LpcReplyChain);
983     }
984 
985     /* Release the lock */
986     KeReleaseGuardedMutex(&LpcpLock);
987 
988     /* Check if we got a reply */
989     if (Status == STATUS_SUCCESS)
990     {
991         /* Check if we have a valid message */
992         if (Message)
993         {
994             LPCTRACE(LPC_SEND_DEBUG,
995                      "Reply Messages: %p/%p\n",
996                      &Message->Request,
997                      (&Message->Request) + 1);
998 
999             /* Move the message */
1000             _SEH2_TRY
1001             {
1002                 LpcpMoveMessage(LpcReply,
1003                                 &Message->Request,
1004                                 (&Message->Request) + 1,
1005                                 0,
1006                                 NULL);
1007             }
1008             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
1009             {
1010                 DPRINT1("Got exception!\n");
1011                 Status = _SEH2_GetExceptionCode();
1012             }
1013             _SEH2_END;
1014 
1015             /* Check if this is an LPC request with data information */
1016             if ((LpcpGetMessageType(&Message->Request) == LPC_REQUEST) &&
1017                 (Message->Request.u2.s2.DataInfoOffset))
1018             {
1019                 /* Save the data information */
1020                 LpcpSaveDataInfoMessage(Port, Message, 0);
1021             }
1022             else
1023             {
1024                 /* Otherwise, just free it */
1025                 LpcpFreeToPortZone(Message, 0);
1026             }
1027         }
1028         else
1029         {
1030             /* We don't have a reply */
1031             Status = STATUS_LPC_REPLY_LOST;
1032         }
1033     }
1034     else
1035     {
1036         /* The wait failed, free the message */
1037         if (Message) LpcpFreeToPortZone(Message, 0);
1038     }
1039 
1040     /* All done */
1041     LPCTRACE(LPC_SEND_DEBUG,
1042              "Port: %p. Status: %d\n",
1043              Port,
1044              Status);
1045     ObDereferenceObject(Port);
1046     if (ConnectionPort) ObDereferenceObject(ConnectionPort);
1047     return Status;
1048 }
1049 
1050 /* EOF */
1051