xref: /reactos/ntoskrnl/lpc/complete.c (revision 1734f297)
1 /*
2 * PROJECT:         ReactOS Kernel
3 * LICENSE:         GPL - See COPYING in the top level directory
4 * FILE:            ntoskrnl/lpc/complete.c
5 * PURPOSE:         Local Procedure Call: Connection Completion
6 * PROGRAMMERS:     Alex Ionescu (alex.ionescu@reactos.org)
7 */
8 
9 /* INCLUDES ******************************************************************/
10 
11 #include <ntoskrnl.h>
12 #define NDEBUG
13 #include <debug.h>
14 
15 /* PRIVATE FUNCTIONS *********************************************************/
16 
17 VOID
18 NTAPI
19 LpcpPrepareToWakeClient(IN PETHREAD Thread)
20 {
21     PAGED_CODE();
22 
23     /* Make sure the thread isn't dying and it has a valid chain */
24     if (!(Thread->LpcExitThreadCalled) &&
25         !(IsListEmpty(&Thread->LpcReplyChain)))
26     {
27         /* Remove it from the list and reinitialize it */
28         RemoveEntryList(&Thread->LpcReplyChain);
29         InitializeListHead(&Thread->LpcReplyChain);
30     }
31 }
32 
33 /* PUBLIC FUNCTIONS **********************************************************/
34 
35 /*
36  * @implemented
37  */
38 NTSTATUS
39 NTAPI
40 NtAcceptConnectPort(OUT PHANDLE PortHandle,
41                     IN PVOID PortContext OPTIONAL,
42                     IN PPORT_MESSAGE ReplyMessage,
43                     IN BOOLEAN AcceptConnection,
44                     IN OUT PPORT_VIEW ServerView OPTIONAL,
45                     OUT PREMOTE_PORT_VIEW ClientView OPTIONAL)
46 {
47     NTSTATUS Status;
48     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
49     PORT_VIEW CapturedServerView;
50     PORT_MESSAGE CapturedReplyMessage;
51     ULONG ConnectionInfoLength;
52     PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
53     PLPCP_CONNECTION_MESSAGE ConnectMessage;
54     PLPCP_MESSAGE Message;
55     PVOID ClientSectionToMap = NULL;
56     HANDLE Handle;
57     PEPROCESS ClientProcess;
58     PETHREAD ClientThread;
59     LARGE_INTEGER SectionOffset;
60 
61     PAGED_CODE();
62     LPCTRACE(LPC_COMPLETE_DEBUG,
63              "Context: %p. Message: %p. Accept: %lx. Views: %p/%p\n",
64              PortContext,
65              ReplyMessage,
66              AcceptConnection,
67              ClientView,
68              ServerView);
69 
70     /* Check if the call comes from user mode */
71     if (PreviousMode != KernelMode)
72     {
73         _SEH2_TRY
74         {
75             /* Probe the PortHandle */
76             ProbeForWriteHandle(PortHandle);
77 
78             /* Probe the basic ReplyMessage structure */
79             ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
80             CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
81             ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
82 
83             /* Probe the connection info */
84             ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1);
85 
86             /* The following parameters are optional */
87 
88             /* Capture the server view */
89             if (ServerView)
90             {
91                 ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
92                 CapturedServerView = *(volatile PORT_VIEW*)ServerView;
93 
94                 /* Validate the size of the server view */
95                 if (CapturedServerView.Length != sizeof(CapturedServerView))
96                 {
97                     /* Invalid size */
98                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
99                 }
100             }
101 
102             /* Capture the client view */
103             if (ClientView)
104             {
105                 ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
106 
107                 /* Validate the size of the client view */
108                 if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView))
109                 {
110                     /* Invalid size */
111                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
112                 }
113             }
114         }
115         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
116         {
117             /* There was an exception, return the exception code */
118             _SEH2_YIELD(return _SEH2_GetExceptionCode());
119         }
120         _SEH2_END;
121     }
122     else
123     {
124         CapturedReplyMessage = *ReplyMessage;
125         ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
126 
127         /* Capture the server view */
128         if (ServerView)
129         {
130             /* Validate the size of the server view */
131             if (ServerView->Length != sizeof(*ServerView))
132             {
133                 /* Invalid size */
134                 return STATUS_INVALID_PARAMETER;
135             }
136             CapturedServerView = *ServerView;
137         }
138 
139         /* Capture the client view */
140         if (ClientView)
141         {
142             /* Validate the size of the client view */
143             if (ClientView->Length != sizeof(*ClientView))
144             {
145                 /* Invalid size */
146                 return STATUS_INVALID_PARAMETER;
147             }
148         }
149     }
150 
151     /* Get the client process and thread */
152     Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
153                                         &ClientProcess,
154                                         &ClientThread);
155     if (!NT_SUCCESS(Status)) return Status;
156 
157     /* Acquire the LPC Lock */
158     KeAcquireGuardedMutex(&LpcpLock);
159 
160     /* Make sure that the client wants a reply, and this is the right one */
161     if (!(LpcpGetMessageFromThread(ClientThread)) ||
162         !(CapturedReplyMessage.MessageId) ||
163         (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId))
164     {
165         /* Not the reply asked for, or no reply wanted, fail */
166         KeReleaseGuardedMutex(&LpcpLock);
167         ObDereferenceObject(ClientProcess);
168         ObDereferenceObject(ClientThread);
169         return STATUS_REPLY_MESSAGE_MISMATCH;
170     }
171 
172     /* Now get the message and connection message */
173     Message = LpcpGetMessageFromThread(ClientThread);
174     ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
175 
176     /* Get the client and connection port as well */
177     ClientPort = ConnectMessage->ClientPort;
178     ConnectionPort = ClientPort->ConnectionPort;
179 
180     /* Make sure that the reply is being sent to the proper server process */
181     if (ConnectionPort->ServerProcess != PsGetCurrentProcess())
182     {
183         /* It's not, so fail */
184         KeReleaseGuardedMutex(&LpcpLock);
185         ObDereferenceObject(ClientProcess);
186         ObDereferenceObject(ClientThread);
187         return STATUS_REPLY_MESSAGE_MISMATCH;
188     }
189 
190     /* At this point, don't let other accept attempts happen */
191     ClientThread->LpcReplyMessage = NULL;
192     ClientThread->LpcReplyMessageId = 0;
193 
194     /* Clear the client port for now as well, then release the lock */
195     ConnectMessage->ClientPort = NULL;
196     KeReleaseGuardedMutex(&LpcpLock);
197 
198     /* Check the connection information length */
199     if (ConnectionInfoLength > ConnectionPort->MaxConnectionInfoLength)
200     {
201         /* Normalize it since it's too large */
202         ConnectionInfoLength = ConnectionPort->MaxConnectionInfoLength;
203     }
204 
205     /* Set the sizes of our reply message */
206     Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
207                                          sizeof(LPCP_CONNECTION_MESSAGE);
208     Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
209                                      Message->Request.u1.s1.DataLength;
210 
211     /* Setup the reply message */
212     Message->Request.u2.s2.Type = LPC_REPLY;
213     Message->Request.u2.s2.DataInfoOffset = 0;
214     Message->Request.ClientId  = CapturedReplyMessage.ClientId;
215     Message->Request.MessageId = CapturedReplyMessage.MessageId;
216     Message->Request.ClientViewSize = 0;
217 
218     _SEH2_TRY
219     {
220         RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
221     }
222     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
223     {
224         Status = _SEH2_GetExceptionCode();
225         _SEH2_YIELD(goto Cleanup);
226     }
227     _SEH2_END;
228 
229     /* At this point, if the caller refused the connection, go to cleanup */
230     if (!AcceptConnection)
231     {
232         DPRINT1("LPC connection was refused\n");
233         goto Cleanup;
234     }
235 
236     /* Otherwise, create the actual port */
237     Status = ObCreateObject(PreviousMode,
238                             LpcPortObjectType,
239                             NULL,
240                             PreviousMode,
241                             NULL,
242                             sizeof(LPCP_PORT_OBJECT),
243                             0,
244                             0,
245                             (PVOID*)&ServerPort);
246     if (!NT_SUCCESS(Status)) goto Cleanup;
247 
248     /* Set it up */
249     RtlZeroMemory(ServerPort, sizeof(LPCP_PORT_OBJECT));
250     ServerPort->PortContext = PortContext;
251     ServerPort->Flags = LPCP_COMMUNICATION_PORT;
252     ServerPort->MaxMessageLength = ConnectionPort->MaxMessageLength;
253     InitializeListHead(&ServerPort->LpcReplyChainHead);
254     InitializeListHead(&ServerPort->LpcDataInfoChainHead);
255 
256     /* Reference the connection port until we're fully setup */
257     ObReferenceObject(ConnectionPort);
258 
259     /* Link the ports together */
260     ServerPort->ConnectionPort = ConnectionPort;
261     ServerPort->ConnectedPort = ClientPort;
262     ClientPort->ConnectedPort = ServerPort;
263 
264     /* Also set the creator CID */
265     ServerPort->Creator = PsGetCurrentThread()->Cid;
266     ClientPort->Creator = Message->Request.ClientId;
267 
268     /* Get the section associated and then clear it, while inside the lock */
269     KeAcquireGuardedMutex(&LpcpLock);
270     ClientSectionToMap = ConnectMessage->SectionToMap;
271     ConnectMessage->SectionToMap = NULL;
272     KeReleaseGuardedMutex(&LpcpLock);
273 
274     /* Now check if there's a client section */
275     if (ClientSectionToMap)
276     {
277         /* Setup the offset */
278         SectionOffset.QuadPart = ConnectMessage->ClientView.SectionOffset;
279 
280         /* Map the section */
281         Status = MmMapViewOfSection(ClientSectionToMap,
282                                     PsGetCurrentProcess(),
283                                     &ServerPort->ClientSectionBase,
284                                     0,
285                                     0,
286                                     &SectionOffset,
287                                     &ConnectMessage->ClientView.ViewSize,
288                                     ViewUnmap,
289                                     0,
290                                     PAGE_READWRITE);
291 
292         /* Update the offset and check for mapping status */
293         ConnectMessage->ClientView.SectionOffset = SectionOffset.LowPart;
294         if (NT_SUCCESS(Status))
295         {
296             /* Set the view base */
297             ConnectMessage->ClientView.ViewRemoteBase = ServerPort->
298                                                         ClientSectionBase;
299 
300             /* Save and reference the mapping process */
301             ServerPort->MappingProcess = PsGetCurrentProcess();
302             ObReferenceObject(ServerPort->MappingProcess);
303         }
304         else
305         {
306             /* Otherwise, quit */
307             ObDereferenceObject(ServerPort);
308             DPRINT1("Client section mapping failed: %lx\n", Status);
309             DPRINT1("View base, offset, size: %p %lx %p\n",
310                     ServerPort->ClientSectionBase,
311                     ConnectMessage->ClientView.ViewSize,
312                     SectionOffset);
313             goto Cleanup;
314         }
315     }
316 
317     /* Check if there's a server section */
318     if (ServerView)
319     {
320         /* FIXME: TODO */
321         UNREFERENCED_PARAMETER(CapturedServerView);
322         ASSERT(FALSE);
323     }
324 
325     /* Reference the server port until it's fully inserted */
326     ObReferenceObject(ServerPort);
327 
328     /* Insert the server port in the namespace */
329     Status = ObInsertObject(ServerPort,
330                             NULL,
331                             PORT_ALL_ACCESS,
332                             0,
333                             NULL,
334                             &Handle);
335     if (!NT_SUCCESS(Status))
336     {
337         /* We failed, remove the extra reference and cleanup */
338         ObDereferenceObject(ServerPort);
339         goto Cleanup;
340     }
341 
342     /* Enter SEH to write back the results */
343     _SEH2_TRY
344     {
345         /* Check if the caller gave a client view */
346         if (ClientView)
347         {
348             /* Fill it out */
349             ClientView->ViewBase = ConnectMessage->ClientView.ViewRemoteBase;
350             ClientView->ViewSize = ConnectMessage->ClientView.ViewSize;
351         }
352 
353         /* Return the handle to user mode */
354         *PortHandle = Handle;
355     }
356     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
357     {
358         /* Cleanup and return the exception code */
359         ObCloseHandle(Handle, PreviousMode);
360         ObDereferenceObject(ServerPort);
361         Status = _SEH2_GetExceptionCode();
362         _SEH2_YIELD(goto Cleanup);
363     }
364     _SEH2_END;
365 
366     LPCTRACE(LPC_COMPLETE_DEBUG,
367              "Handle: %p. Messages: %p/%p. Ports: %p/%p/%p\n",
368              Handle,
369              Message,
370              ConnectMessage,
371              ServerPort,
372              ClientPort,
373              ConnectionPort);
374 
375     /* If there was no port context, use the handle by default */
376     if (!PortContext) ServerPort->PortContext = Handle;
377     ServerPort->ClientThread = ClientThread;
378 
379     /* Set this message as the LPC Reply message while holding the lock */
380     KeAcquireGuardedMutex(&LpcpLock);
381     ClientThread->LpcReplyMessage = Message;
382     KeReleaseGuardedMutex(&LpcpLock);
383 
384     /* Clear the thread pointer so it doesn't get cleaned later */
385     ClientThread = NULL;
386 
387     /* Remove the extra reference we had added */
388     ObDereferenceObject(ServerPort);
389 
390 Cleanup:
391     /* If there was a section, dereference it */
392     if (ClientSectionToMap) ObDereferenceObject(ClientSectionToMap);
393 
394     /* Check if we got here while still having a client thread */
395     if (ClientThread)
396     {
397         KeAcquireGuardedMutex(&LpcpLock);
398         ClientThread->LpcReplyMessage = Message;
399         LpcpPrepareToWakeClient(ClientThread);
400         KeReleaseGuardedMutex(&LpcpLock);
401         LpcpCompleteWait(&ClientThread->LpcReplySemaphore);
402         ObDereferenceObject(ClientThread);
403     }
404 
405     /* Dereference the client port if we have one, and the process */
406     LPCTRACE(LPC_COMPLETE_DEBUG,
407              "Status: %lx. Thread: %p. Process: [%.16s]\n",
408              Status,
409              ClientThread,
410              ClientProcess->ImageFileName);
411     if (ClientPort) ObDereferenceObject(ClientPort);
412     ObDereferenceObject(ClientProcess);
413     return Status;
414 }
415 
416 /*
417  * @implemented
418  */
419 NTSTATUS
420 NTAPI
421 NtCompleteConnectPort(IN HANDLE PortHandle)
422 {
423     NTSTATUS Status;
424     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
425     PLPCP_PORT_OBJECT Port;
426     PETHREAD Thread;
427 
428     PAGED_CODE();
429     LPCTRACE(LPC_COMPLETE_DEBUG, "Handle: %p\n", PortHandle);
430 
431     /* Get the Port Object */
432     Status = ObReferenceObjectByHandle(PortHandle,
433                                        PORT_ALL_ACCESS,
434                                        LpcPortObjectType,
435                                        PreviousMode,
436                                        (PVOID*)&Port,
437                                        NULL);
438     if (!NT_SUCCESS(Status)) return Status;
439 
440     /* Make sure this is a connection port */
441     if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
442     {
443         /* It isn't, fail */
444         ObDereferenceObject(Port);
445         return STATUS_INVALID_PORT_HANDLE;
446     }
447 
448     /* Acquire the lock */
449     KeAcquireGuardedMutex(&LpcpLock);
450 
451     /* Make sure we have a client thread */
452     if (!Port->ClientThread)
453     {
454         /* We don't, fail */
455         KeReleaseGuardedMutex(&LpcpLock);
456         ObDereferenceObject(Port);
457         return STATUS_INVALID_PARAMETER;
458     }
459 
460     /* Get the thread */
461     Thread = Port->ClientThread;
462 
463     /* Make sure it has a reply message */
464     if (!LpcpGetMessageFromThread(Thread))
465     {
466         /* It doesn't, quit */
467         KeReleaseGuardedMutex(&LpcpLock);
468         ObDereferenceObject(Port);
469         return STATUS_SUCCESS;
470     }
471 
472     /* Clear the client thread and wake it up */
473     Port->ClientThread = NULL;
474     LpcpPrepareToWakeClient(Thread);
475 
476     /* Release the lock and wait for an answer */
477     KeReleaseGuardedMutex(&LpcpLock);
478     LpcpCompleteWait(&Thread->LpcReplySemaphore);
479 
480     /* Dereference the Thread and Port and return */
481     ObDereferenceObject(Port);
482     ObDereferenceObject(Thread);
483     LPCTRACE(LPC_COMPLETE_DEBUG, "Port: %p. Thread: %p\n", Port, Thread);
484     return Status;
485 }
486 
487 /* EOF */
488