xref: /reactos/ntoskrnl/lpc/complete.c (revision 84344399)
/*
 * PROJECT:         ReactOS Kernel
 * LICENSE:         GPL - See COPYING in the top level directory
 * FILE:            ntoskrnl/lpc/complete.c
 * PURPOSE:         Local Procedure Call: Connection Completion
 * PROGRAMMERS:     Alex Ionescu (alex.ionescu@reactos.org)
 */
8 
9 /* INCLUDES ******************************************************************/
10 
11 #include <ntoskrnl.h>
12 #define NDEBUG
13 #include <debug.h>
14 
15 /* PRIVATE FUNCTIONS *********************************************************/
16 
17 VOID
18 NTAPI
19 LpcpPrepareToWakeClient(IN PETHREAD Thread)
20 {
21     PAGED_CODE();
22 
23     /* Make sure the thread isn't dying and it has a valid chain */
24     if (!(Thread->LpcExitThreadCalled) &&
25         !(IsListEmpty(&Thread->LpcReplyChain)))
26     {
27         /* Remove it from the list and reinitialize it */
28         RemoveEntryList(&Thread->LpcReplyChain);
29         InitializeListHead(&Thread->LpcReplyChain);
30     }
31 }
32 
33 /* PUBLIC FUNCTIONS **********************************************************/
34 
35 /*
36  * @implemented
37  */
38 NTSTATUS
39 NTAPI
40 NtAcceptConnectPort(OUT PHANDLE PortHandle,
41                     IN PVOID PortContext OPTIONAL,
42                     IN PPORT_MESSAGE ReplyMessage,
43                     IN BOOLEAN AcceptConnection,
44                     IN OUT PPORT_VIEW ServerView OPTIONAL,
45                     OUT PREMOTE_PORT_VIEW ClientView OPTIONAL)
46 {
47     NTSTATUS Status;
48     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
49     PORT_VIEW CapturedServerView;
50     PORT_MESSAGE CapturedReplyMessage;
51     ULONG ConnectionInfoLength;
52     PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
53     PLPCP_CONNECTION_MESSAGE ConnectMessage;
54     PLPCP_MESSAGE Message;
55     PVOID ClientSectionToMap = NULL;
56     HANDLE Handle;
57     PEPROCESS ClientProcess;
58     PETHREAD ClientThread;
59     LARGE_INTEGER SectionOffset;
60 
61     PAGED_CODE();
62 
63     LPCTRACE(LPC_COMPLETE_DEBUG,
64              "Context: %p. Message: %p. Accept: %lx. Views: %p/%p\n",
65              PortContext,
66              ReplyMessage,
67              AcceptConnection,
68              ClientView,
69              ServerView);
70 
71     /* Check if the call comes from user mode */
72     if (PreviousMode != KernelMode)
73     {
74         _SEH2_TRY
75         {
76             /* Probe the PortHandle */
77             ProbeForWriteHandle(PortHandle);
78 
79             /* Probe the basic ReplyMessage structure */
80             ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
81             CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
82             ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
83 
84             /* Probe the connection info */
85             ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1);
86 
87             /* The following parameters are optional */
88 
89             /* Capture the server view */
90             if (ServerView)
91             {
92                 ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
93                 CapturedServerView = *(volatile PORT_VIEW*)ServerView;
94 
95                 /* Validate the size of the server view */
96                 if (CapturedServerView.Length != sizeof(CapturedServerView))
97                 {
98                     /* Invalid size */
99                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
100                 }
101             }
102 
103             /* Capture the client view */
104             if (ClientView)
105             {
106                 ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
107 
108                 /* Validate the size of the client view */
109                 if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView))
110                 {
111                     /* Invalid size */
112                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
113                 }
114             }
115         }
116         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
117         {
118             /* There was an exception, return the exception code */
119             _SEH2_YIELD(return _SEH2_GetExceptionCode());
120         }
121         _SEH2_END;
122     }
123     else
124     {
125         CapturedReplyMessage = *ReplyMessage;
126         ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
127 
128         /* Capture the server view */
129         if (ServerView)
130         {
131             /* Validate the size of the server view */
132             if (ServerView->Length != sizeof(*ServerView))
133             {
134                 /* Invalid size */
135                 return STATUS_INVALID_PARAMETER;
136             }
137             CapturedServerView = *ServerView;
138         }
139 
140         /* Capture the client view */
141         if (ClientView)
142         {
143             /* Validate the size of the client view */
144             if (ClientView->Length != sizeof(*ClientView))
145             {
146                 /* Invalid size */
147                 return STATUS_INVALID_PARAMETER;
148             }
149         }
150     }
151 
152     /* Get the client process and thread */
153     Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
154                                         &ClientProcess,
155                                         &ClientThread);
156     if (!NT_SUCCESS(Status)) return Status;
157 
158     /* Acquire the LPC Lock */
159     KeAcquireGuardedMutex(&LpcpLock);
160 
161     /* Make sure that the client wants a reply, and this is the right one */
162     if (!(LpcpGetMessageFromThread(ClientThread)) ||
163         !(CapturedReplyMessage.MessageId) ||
164         (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId))
165     {
166         /* Not the reply asked for, or no reply wanted, fail */
167         KeReleaseGuardedMutex(&LpcpLock);
168         ObDereferenceObject(ClientProcess);
169         ObDereferenceObject(ClientThread);
170         return STATUS_REPLY_MESSAGE_MISMATCH;
171     }
172 
173     /* Now get the message and connection message */
174     Message = LpcpGetMessageFromThread(ClientThread);
175     ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
176 
177     /* Get the client and connection port as well */
178     ClientPort = ConnectMessage->ClientPort;
179     ConnectionPort = ClientPort->ConnectionPort;
180 
181     /* Make sure that the reply is being sent to the proper server process */
182     if (ConnectionPort->ServerProcess != PsGetCurrentProcess())
183     {
184         /* It's not, so fail */
185         KeReleaseGuardedMutex(&LpcpLock);
186         ObDereferenceObject(ClientProcess);
187         ObDereferenceObject(ClientThread);
188         return STATUS_REPLY_MESSAGE_MISMATCH;
189     }
190 
191     /* At this point, don't let other accept attempts happen */
192     ClientThread->LpcReplyMessage = NULL;
193     ClientThread->LpcReplyMessageId = 0;
194 
195     /* Clear the client port for now as well, then release the lock */
196     ConnectMessage->ClientPort = NULL;
197     KeReleaseGuardedMutex(&LpcpLock);
198 
199     /* Check the connection information length */
200     if (ConnectionInfoLength > ConnectionPort->MaxConnectionInfoLength)
201     {
202         /* Normalize it since it's too large */
203         ConnectionInfoLength = ConnectionPort->MaxConnectionInfoLength;
204     }
205 
206     /* Set the sizes of our reply message */
207     Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
208                                          sizeof(LPCP_CONNECTION_MESSAGE);
209     Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
210                                      Message->Request.u1.s1.DataLength;
211 
212     /* Setup the reply message */
213     Message->Request.u2.s2.Type = LPC_REPLY;
214     Message->Request.u2.s2.DataInfoOffset = 0;
215     Message->Request.ClientId  = CapturedReplyMessage.ClientId;
216     Message->Request.MessageId = CapturedReplyMessage.MessageId;
217     Message->Request.ClientViewSize = 0;
218 
219     _SEH2_TRY
220     {
221         RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
222     }
223     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
224     {
225         Status = _SEH2_GetExceptionCode();
226         _SEH2_YIELD(goto Cleanup);
227     }
228     _SEH2_END;
229 
230     /* At this point, if the caller refused the connection, go to cleanup */
231     if (!AcceptConnection)
232     {
233         DPRINT1("LPC connection was refused\n");
234         goto Cleanup;
235     }
236 
237     /* Otherwise, create the actual port */
238     Status = ObCreateObject(PreviousMode,
239                             LpcPortObjectType,
240                             NULL,
241                             PreviousMode,
242                             NULL,
243                             sizeof(LPCP_PORT_OBJECT),
244                             0,
245                             0,
246                             (PVOID*)&ServerPort);
247     if (!NT_SUCCESS(Status)) goto Cleanup;
248 
249     /* Set it up */
250     RtlZeroMemory(ServerPort, sizeof(LPCP_PORT_OBJECT));
251     ServerPort->PortContext = PortContext;
252     ServerPort->Flags = LPCP_COMMUNICATION_PORT;
253     ServerPort->MaxMessageLength = ConnectionPort->MaxMessageLength;
254     InitializeListHead(&ServerPort->LpcReplyChainHead);
255     InitializeListHead(&ServerPort->LpcDataInfoChainHead);
256 
257     /* Reference the connection port until we're fully setup */
258     ObReferenceObject(ConnectionPort);
259 
260     /* Link the ports together */
261     ServerPort->ConnectionPort = ConnectionPort;
262     ServerPort->ConnectedPort = ClientPort;
263     ClientPort->ConnectedPort = ServerPort;
264 
265     /* Also set the creator CID */
266     ServerPort->Creator = PsGetCurrentThread()->Cid;
267     ClientPort->Creator = Message->Request.ClientId;
268 
269     /* Get the section associated and then clear it, while inside the lock */
270     KeAcquireGuardedMutex(&LpcpLock);
271     ClientSectionToMap = ConnectMessage->SectionToMap;
272     ConnectMessage->SectionToMap = NULL;
273     KeReleaseGuardedMutex(&LpcpLock);
274 
275     /* Now check if there's a client section */
276     if (ClientSectionToMap)
277     {
278         /* Setup the offset */
279         SectionOffset.QuadPart = ConnectMessage->ClientView.SectionOffset;
280 
281         /* Map the section */
282         Status = MmMapViewOfSection(ClientSectionToMap,
283                                     PsGetCurrentProcess(),
284                                     &ServerPort->ClientSectionBase,
285                                     0,
286                                     0,
287                                     &SectionOffset,
288                                     &ConnectMessage->ClientView.ViewSize,
289                                     ViewUnmap,
290                                     0,
291                                     PAGE_READWRITE);
292 
293         /* Update the offset and check for mapping status */
294         ConnectMessage->ClientView.SectionOffset = SectionOffset.LowPart;
295         if (NT_SUCCESS(Status))
296         {
297             /* Set the view base */
298             ConnectMessage->ClientView.ViewRemoteBase = ServerPort->
299                                                         ClientSectionBase;
300 
301             /* Save and reference the mapping process */
302             ServerPort->MappingProcess = PsGetCurrentProcess();
303             ObReferenceObject(ServerPort->MappingProcess);
304         }
305         else
306         {
307             /* Otherwise, quit */
308             ObDereferenceObject(ServerPort);
309             DPRINT1("Client section mapping failed: %lx\n", Status);
310             LPCTRACE(LPC_COMPLETE_DEBUG,
311                      "View base, offset, size: %p %lx %p\n",
312                      ServerPort->ClientSectionBase,
313                      ConnectMessage->ClientView.ViewSize,
314                      SectionOffset);
315             goto Cleanup;
316         }
317     }
318 
319     /* Check if there's a server section */
320     if (ServerView)
321     {
322         /* FIXME: TODO */
323         UNREFERENCED_PARAMETER(CapturedServerView);
324         ASSERT(FALSE);
325     }
326 
327     /* Reference the server port until it's fully inserted */
328     ObReferenceObject(ServerPort);
329 
330     /* Insert the server port in the namespace */
331     Status = ObInsertObject(ServerPort,
332                             NULL,
333                             PORT_ALL_ACCESS,
334                             0,
335                             NULL,
336                             &Handle);
337     if (!NT_SUCCESS(Status))
338     {
339         /* We failed, remove the extra reference and cleanup */
340         ObDereferenceObject(ServerPort);
341         goto Cleanup;
342     }
343 
344     /* Enter SEH to write back the results */
345     _SEH2_TRY
346     {
347         /* Check if the caller gave a client view */
348         if (ClientView)
349         {
350             /* Fill it out */
351             ClientView->ViewBase = ConnectMessage->ClientView.ViewRemoteBase;
352             ClientView->ViewSize = ConnectMessage->ClientView.ViewSize;
353         }
354 
355         /* Return the handle to user mode */
356         *PortHandle = Handle;
357     }
358     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
359     {
360         /* Cleanup and return the exception code */
361         ObCloseHandle(Handle, PreviousMode);
362         ObDereferenceObject(ServerPort);
363         Status = _SEH2_GetExceptionCode();
364         _SEH2_YIELD(goto Cleanup);
365     }
366     _SEH2_END;
367 
368     LPCTRACE(LPC_COMPLETE_DEBUG,
369              "Handle: %p. Messages: %p/%p. Ports: %p/%p/%p\n",
370              Handle,
371              Message,
372              ConnectMessage,
373              ServerPort,
374              ClientPort,
375              ConnectionPort);
376 
377     /* If there was no port context, use the handle by default */
378     if (!PortContext) ServerPort->PortContext = Handle;
379     ServerPort->ClientThread = ClientThread;
380 
381     /* Set this message as the LPC Reply message while holding the lock */
382     KeAcquireGuardedMutex(&LpcpLock);
383     ClientThread->LpcReplyMessage = Message;
384     KeReleaseGuardedMutex(&LpcpLock);
385 
386     /* Clear the thread pointer so it doesn't get cleaned later */
387     ClientThread = NULL;
388 
389     /* Remove the extra reference we had added */
390     ObDereferenceObject(ServerPort);
391 
392 Cleanup:
393     /* If there was a section, dereference it */
394     if (ClientSectionToMap) ObDereferenceObject(ClientSectionToMap);
395 
396     /* Check if we got here while still having a client thread */
397     if (ClientThread)
398     {
399         KeAcquireGuardedMutex(&LpcpLock);
400         ClientThread->LpcReplyMessage = Message;
401         LpcpPrepareToWakeClient(ClientThread);
402         KeReleaseGuardedMutex(&LpcpLock);
403         LpcpCompleteWait(&ClientThread->LpcReplySemaphore);
404         ObDereferenceObject(ClientThread);
405     }
406 
407     /* Dereference the client port if we have one, and the process */
408     LPCTRACE(LPC_COMPLETE_DEBUG,
409              "Status: %lx. Thread: %p. Process: [%.16s]\n",
410              Status,
411              ClientThread,
412              ClientProcess->ImageFileName);
413     if (ClientPort) ObDereferenceObject(ClientPort);
414     ObDereferenceObject(ClientProcess);
415     return Status;
416 }
417 
418 /*
419  * @implemented
420  */
421 NTSTATUS
422 NTAPI
423 NtCompleteConnectPort(IN HANDLE PortHandle)
424 {
425     NTSTATUS Status;
426     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
427     PLPCP_PORT_OBJECT Port;
428     PETHREAD Thread;
429 
430     PAGED_CODE();
431     LPCTRACE(LPC_COMPLETE_DEBUG, "Handle: %p\n", PortHandle);
432 
433     /* Get the Port Object */
434     Status = ObReferenceObjectByHandle(PortHandle,
435                                        PORT_ALL_ACCESS,
436                                        LpcPortObjectType,
437                                        PreviousMode,
438                                        (PVOID*)&Port,
439                                        NULL);
440     if (!NT_SUCCESS(Status)) return Status;
441 
442     /* Make sure this is a connection port */
443     if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
444     {
445         /* It isn't, fail */
446         ObDereferenceObject(Port);
447         return STATUS_INVALID_PORT_HANDLE;
448     }
449 
450     /* Acquire the lock */
451     KeAcquireGuardedMutex(&LpcpLock);
452 
453     /* Make sure we have a client thread */
454     if (!Port->ClientThread)
455     {
456         /* We don't, fail */
457         KeReleaseGuardedMutex(&LpcpLock);
458         ObDereferenceObject(Port);
459         return STATUS_INVALID_PARAMETER;
460     }
461 
462     /* Get the thread */
463     Thread = Port->ClientThread;
464 
465     /* Make sure it has a reply message */
466     if (!LpcpGetMessageFromThread(Thread))
467     {
468         /* It doesn't, quit */
469         KeReleaseGuardedMutex(&LpcpLock);
470         ObDereferenceObject(Port);
471         return STATUS_SUCCESS;
472     }
473 
474     /* Clear the client thread and wake it up */
475     Port->ClientThread = NULL;
476     LpcpPrepareToWakeClient(Thread);
477 
478     /* Release the lock and wait for an answer */
479     KeReleaseGuardedMutex(&LpcpLock);
480     LpcpCompleteWait(&Thread->LpcReplySemaphore);
481 
482     /* Dereference the Thread and Port and return */
483     ObDereferenceObject(Port);
484     ObDereferenceObject(Thread);
485     LPCTRACE(LPC_COMPLETE_DEBUG, "Port: %p. Thread: %p\n", Port, Thread);
486     return Status;
487 }
488 
489 /* EOF */
490