xref: /reactos/subsystems/csr/csrlib/connect.c (revision 0b366ea1)
1 /*
2  * PROJECT:     ReactOS Client/Server Runtime SubSystem
3  * LICENSE:     GPL-2.0-or-later (https://spdx.org/licenses/GPL-2.0-or-later)
4  * PURPOSE:     CSR Client Library - CSR connection and calling
5  * COPYRIGHT:   Copyright 2005-2013 Alex Ionescu <alex@relsoft.net>
6  *              Copyright 2012-2022 Hermès Bélusca-Maïto <hermes.belusca-maito@reactos.org>
7  */
8 
9 /* INCLUDES *******************************************************************/
10 
11 #include "csrlib.h"
12 
13 #define NTOS_MODE_USER
14 #include <ndk/ldrfuncs.h>
15 #include <ndk/lpcfuncs.h>
16 #include <ndk/mmfuncs.h>
17 #include <ndk/obfuncs.h>
18 #include <ndk/umfuncs.h>
19 
20 #include <csrsrv.h> // For CSR_CSRSS_SECTION_SIZE
21 
22 #define NDEBUG
23 #include <debug.h>
24 
25 /* GLOBALS ********************************************************************/
26 
/*
 * Client-side CSR connection state used by the routines in this file.
 * NOTE(review): these are not static; presumably other csrlib modules
 * (e.g. the capture-buffer code) reference some of them -- confirm before
 * changing their linkage.
 */

/* Handle of the connected CSR API LPC port (set by CsrpConnectToServer). */
HANDLE CsrApiPort;
/* Process ID of the CSR server, taken from the connection reply. */
HANDLE CsrProcessId;
/* Heap for port memory: created over the client view of the shared port
   section, or the process heap itself when running inside CSR. */
HANDLE CsrPortHeap;
/* (server view base - client view base) of the port section: added to a
   client pointer into the section to obtain the server-side pointer. */
ULONG_PTR CsrPortMemoryDelta;
/* Set by CsrClientConnectToServer when the current image is a native-
   subsystem image, i.e. when this code runs inside the CSR process. */
BOOLEAN InsideCsrProcess = FALSE;

/* Signature of the server-to-server dispatch routine looked up in csrsrv. */
typedef NTSTATUS
(NTAPI *PCSR_SERVER_API_ROUTINE)(
    _In_ PCSR_API_MESSAGE Request,
    _Inout_ PCSR_API_MESSAGE Reply);

/* Address of csrsrv's "CsrCallServerFromServer", used to call the server
   directly (without LPC) when InsideCsrProcess is TRUE. */
PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
39 
40 /* FUNCTIONS ******************************************************************/
41 
42 static NTSTATUS
43 CsrpConnectToServer(
44     _In_ PCWSTR ObjectDirectory)
45 {
46     NTSTATUS Status;
47     SIZE_T PortNameLength;
48     UNICODE_STRING PortName;
49     LARGE_INTEGER CsrSectionViewSize;
50     HANDLE CsrSectionHandle;
51     PORT_VIEW LpcWrite;
52     REMOTE_PORT_VIEW LpcRead;
53     SECURITY_QUALITY_OF_SERVICE SecurityQos;
54     SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
55     PSID SystemSid = NULL;
56     CSR_API_CONNECTINFO ConnectionInfo;
57     ULONG ConnectionInfoLength = sizeof(ConnectionInfo);
58 
59     DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
60 
61     /* Calculate the total port name size */
62     PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
63                      sizeof(CSR_PORT_NAME);
64     if (PortNameLength > UNICODE_STRING_MAX_BYTES)
65     {
66         DPRINT1("PortNameLength too big: %Iu\n", PortNameLength);
67         return STATUS_NAME_TOO_LONG;
68     }
69 
70     /* Set the port name */
71     PortName.Length = 0;
72     PortName.MaximumLength = (USHORT)PortNameLength;
73 
74     /* Allocate a buffer for it */
75     PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
76     if (PortName.Buffer == NULL)
77     {
78         return STATUS_INSUFFICIENT_RESOURCES;
79     }
80 
81     /* Create the name */
82     RtlAppendUnicodeToString(&PortName, ObjectDirectory);
83     RtlAppendUnicodeToString(&PortName, L"\\");
84     RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
85 
86     /* Create a section for the port memory */
87     CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
88     Status = NtCreateSection(&CsrSectionHandle,
89                              SECTION_ALL_ACCESS,
90                              NULL,
91                              &CsrSectionViewSize,
92                              PAGE_READWRITE,
93                              SEC_RESERVE,
94                              NULL);
95     if (!NT_SUCCESS(Status))
96     {
97         DPRINT1("Failure allocating CSR Section\n");
98         return Status;
99     }
100 
101     /* Set up the port view structures to match them with the section */
102     LpcWrite.Length = sizeof(LpcWrite);
103     LpcWrite.SectionHandle = CsrSectionHandle;
104     LpcWrite.SectionOffset = 0;
105     LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
106     LpcWrite.ViewBase = 0;
107     LpcWrite.ViewRemoteBase = 0;
108     LpcRead.Length = sizeof(LpcRead);
109     LpcRead.ViewSize = 0;
110     LpcRead.ViewBase = 0;
111 
112     /* Setup the QoS */
113     SecurityQos.ImpersonationLevel = SecurityImpersonation;
114     SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
115     SecurityQos.EffectiveOnly = TRUE;
116 
117     /* Setup the connection info */
118     ConnectionInfo.DebugFlags = 0;
119 
120     /* Create a SID for us */
121     Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
122                                          1,
123                                          SECURITY_LOCAL_SYSTEM_RID,
124                                          0,
125                                          0,
126                                          0,
127                                          0,
128                                          0,
129                                          0,
130                                          0,
131                                          &SystemSid);
132     if (!NT_SUCCESS(Status))
133     {
134         /* Failure */
135         DPRINT1("Couldn't allocate SID\n");
136         NtClose(CsrSectionHandle);
137         return Status;
138     }
139 
140     /* Connect to the port */
141     Status = NtSecureConnectPort(&CsrApiPort,
142                                  &PortName,
143                                  &SecurityQos,
144                                  &LpcWrite,
145                                  SystemSid,
146                                  &LpcRead,
147                                  NULL,
148                                  &ConnectionInfo,
149                                  &ConnectionInfoLength);
150     RtlFreeSid(SystemSid);
151     NtClose(CsrSectionHandle);
152     if (!NT_SUCCESS(Status))
153     {
154         /* Failure */
155         DPRINT1("Couldn't connect to CSR port\n");
156         return Status;
157     }
158 
159     /* Save the delta between the sections, for capture usage later */
160     CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
161                          (ULONG_PTR)LpcWrite.ViewBase;
162 
163     /* Save the Process */
164     CsrProcessId = ConnectionInfo.ServerProcessId;
165 
166     /* Save CSR Section data */
167     NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
168     NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
169     NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
170 
171     /* Create the port heap */
172     CsrPortHeap = RtlCreateHeap(0,
173                                 LpcWrite.ViewBase,
174                                 LpcWrite.ViewSize,
175                                 PAGE_SIZE,
176                                 0,
177                                 0);
178     if (CsrPortHeap == NULL)
179     {
180         /* Failure */
181         DPRINT1("Couldn't create heap for CSR port\n");
182         NtClose(CsrApiPort);
183         CsrApiPort = NULL;
184         return STATUS_INSUFFICIENT_RESOURCES;
185     }
186 
187     /* Return success */
188     return STATUS_SUCCESS;
189 }
190 
191 /*
192  * @implemented
193  */
194 NTSTATUS
195 NTAPI
196 CsrClientConnectToServer(
197     _In_ PCWSTR ObjectDirectory,
198     _In_ ULONG ServerId,
199     _In_ PVOID ConnectionInfo,
200     _Inout_ PULONG ConnectionInfoSize,
201     _Out_ PBOOLEAN ServerToServerCall)
202 {
203     NTSTATUS Status;
204     PIMAGE_NT_HEADERS NtHeader;
205 
206     DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
207 
208     /* Validate the Connection Info */
209     if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
210     {
211         DPRINT1("Connection info given, but no length\n");
212         return STATUS_INVALID_PARAMETER;
213     }
214 
215     /* Check if we're inside a CSR Process */
216     if (InsideCsrProcess)
217     {
218         /* Tell the client that we're already inside CSR */
219         if (ServerToServerCall) *ServerToServerCall = TRUE;
220         return STATUS_SUCCESS;
221     }
222 
223     /*
224      * We might be in a CSR Process but not know it, if this is the first call.
225      * So let's find out.
226      */
227     if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
228     {
229         /* The image isn't valid */
230         DPRINT1("Invalid image\n");
231         return STATUS_INVALID_IMAGE_FORMAT;
232     }
233     InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
234 
235     /* Now we can check if we are inside or not */
236     if (InsideCsrProcess)
237     {
238         UNICODE_STRING CsrSrvName;
239         HANDLE hCsrSrv;
240         ANSI_STRING CsrServerRoutineName;
241 
242         /* We're inside, so let's find csrsrv */
243         RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
244         Status = LdrGetDllHandle(NULL,
245                                  NULL,
246                                  &CsrSrvName,
247                                  &hCsrSrv);
248 
249         /* Now get the Server to Server routine */
250         RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
251         Status = LdrGetProcedureAddress(hCsrSrv,
252                                         &CsrServerRoutineName,
253                                         0L,
254                                         (PVOID*)&CsrServerApiRoutine);
255 
256         /* Use the local heap as port heap */
257         CsrPortHeap = RtlGetProcessHeap();
258 
259         /* Tell the caller we're inside the server */
260         if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
261         return STATUS_SUCCESS;
262     }
263 
264     /* Now check if connection info is given */
265     if (ConnectionInfo)
266     {
267         CSR_API_MESSAGE ApiMessage;
268         PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
269         PCSR_CAPTURE_BUFFER CaptureBuffer;
270 
271         /* Well, we're definitely in a client now */
272         InsideCsrProcess = FALSE;
273 
274         /* Do we have a connection to CSR yet? */
275         if (!CsrApiPort)
276         {
277             /* No, set it up now */
278             Status = CsrpConnectToServer(ObjectDirectory);
279             if (!NT_SUCCESS(Status))
280             {
281                 /* Failed */
282                 DPRINT1("Failure to connect to CSR\n");
283                 return Status;
284             }
285         }
286 
287         /* Setup the connect message header */
288         ClientConnect->ServerId = ServerId;
289         ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
290 
291         /* Setup a buffer for the connection info */
292         CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
293         if (CaptureBuffer == NULL)
294         {
295             return STATUS_INSUFFICIENT_RESOURCES;
296         }
297 
298         /* Capture the connection info data */
299         CsrCaptureMessageBuffer(CaptureBuffer,
300                                 ConnectionInfo,
301                                 ClientConnect->ConnectionInfoSize,
302                                 &ClientConnect->ConnectionInfo);
303 
304         /* Return the allocated length */
305         *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
306 
307         /* Call CSR */
308         Status = CsrClientCallServer(&ApiMessage,
309                                      CaptureBuffer,
310                                      CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
311                                      sizeof(*ClientConnect));
312 
313         /* Copy the updated connection info data back into the user buffer */
314         RtlMoveMemory(ConnectionInfo,
315                       ClientConnect->ConnectionInfo,
316                       *ConnectionInfoSize);
317 
318         /* Free the capture buffer */
319         CsrFreeCaptureBuffer(CaptureBuffer);
320     }
321     else
322     {
323         /* No connection info, just return */
324         Status = STATUS_SUCCESS;
325     }
326 
327     /* Let the caller know if this was server to server */
328     DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
329     if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
330 
331     return Status;
332 }
333 
#if 0
//
// Disabled compile-time illustration of structure tail padding: a structure
// can be padded at its end, so the size of the whole structure minus the
// size of its last field is not necessarily equal to the offset of that last
// field. The length computations in CsrClientCallServer appear to account
// for this by deriving DataLength from FIELD_OFFSET(CSR_API_MESSAGE, Data)
// and TotalLength from sizeof(CSR_API_MESSAGE) - sizeof(Data).
//
// NOTE(review): the hard-coded sizes asserted below (e.g. a PORT_MESSAGE of
// 0x18 bytes) presumably hold only for a 32-bit build -- confirm before
// enabling this block.
//
typedef struct _TEST_EMBEDDED
{
    ULONG One;
    ULONG Two;
    ULONG Three;
} TEST_EMBEDDED;

typedef struct _TEST
{
    PORT_MESSAGE h;
    TEST_EMBEDDED Three;
} TEST;

C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);

// Both of these hold only because TEST has trailing padding.
C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
#endif
360 
361 /*
362  * @implemented
363  */
364 NTSTATUS
365 NTAPI
366 CsrClientCallServer(
367     _Inout_ PCSR_API_MESSAGE ApiMessage,
368     _Inout_opt_ PCSR_CAPTURE_BUFFER CaptureBuffer,
369     _In_ CSR_API_NUMBER ApiNumber,
370     _In_ ULONG DataLength)
371 {
372     NTSTATUS Status;
373 
374     /* Make sure the length is valid */
375     if (DataLength > (MAXSHORT - sizeof(CSR_API_MESSAGE)))
376     {
377         DPRINT1("DataLength too big: %lu\n", DataLength);
378         return STATUS_INVALID_PARAMETER;
379     }
380 
381     /* Fill out the Port Message Header */
382     ApiMessage->Header.u2.ZeroInit = 0;
383     /* DataLength = user_data_size + anything between
384      * header and data, including intermediate padding */
385     ApiMessage->Header.u1.s1.DataLength = (CSHORT)DataLength +
386         FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header);
387     /* TotalLength = header_size + DataLength + any structure trailing padding */
388     ApiMessage->Header.u1.s1.TotalLength = (CSHORT)DataLength +
389         sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data);
390 
391     /* Fill out the CSR Header */
392     ApiMessage->ApiNumber = ApiNumber;
393     ApiMessage->CsrCaptureData = NULL;
394 
395     DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
396            ApiNumber,
397            ApiMessage->Header.u1.s1.DataLength,
398            ApiMessage->Header.u1.s1.TotalLength);
399 
400     /* Check if we are already inside a CSR Server */
401     if (!InsideCsrProcess)
402     {
403         ULONG PointerCount;
404         PULONG_PTR OffsetPointer;
405 
406         /* Check if we got a Capture Buffer */
407         if (CaptureBuffer)
408         {
409             /*
410              * We have to convert from our local (client) view
411              * to the remote (server) view.
412              */
413             ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
414                 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
415 
416             /* Lock the buffer */
417             CaptureBuffer->BufferEnd = NULL;
418 
419             /*
420              * Each client pointer inside the CSR message is converted into
421              * a server pointer, and each pointer to these message pointers
422              * is converted into an offset.
423              */
424             PointerCount  = CaptureBuffer->PointerCount;
425             OffsetPointer = CaptureBuffer->PointerOffsetsArray;
426             while (PointerCount--)
427             {
428                 if (*OffsetPointer != 0)
429                 {
430                     *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
431                     *OffsetPointer -= (ULONG_PTR)ApiMessage;
432                 }
433                 ++OffsetPointer;
434             }
435         }
436 
437         /* Send the LPC Message */
438         Status = NtRequestWaitReplyPort(CsrApiPort,
439                                         &ApiMessage->Header,
440                                         &ApiMessage->Header);
441 
442         /* Check if we got a Capture Buffer */
443         if (CaptureBuffer)
444         {
445             /*
446              * We have to convert back from the remote (server) view
447              * to our local (client) view.
448              */
449             ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
450                 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
451 
452             /*
453              * Convert back the offsets into pointers to CSR message
454              * pointers, and convert back these message server pointers
455              * into client pointers.
456              */
457             PointerCount  = CaptureBuffer->PointerCount;
458             OffsetPointer = CaptureBuffer->PointerOffsetsArray;
459             while (PointerCount--)
460             {
461                 if (*OffsetPointer != 0)
462                 {
463                     *OffsetPointer += (ULONG_PTR)ApiMessage;
464                     *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
465                 }
466                 ++OffsetPointer;
467             }
468         }
469 
470         /* Check for success */
471         if (!NT_SUCCESS(Status))
472         {
473             /* We failed. Overwrite the return value with the failure. */
474             DPRINT1("LPC Failed: %lx\n", Status);
475             ApiMessage->Status = Status;
476         }
477     }
478     else
479     {
480         /* This is a server-to-server call */
481         DPRINT("Server-to-server call\n");
482 
483         /* Save our CID; we check this equality inside CsrValidateMessageBuffer */
484         ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
485 
486         /* Do a direct call */
487         Status = CsrServerApiRoutine(ApiMessage, ApiMessage);
488 
489         /* Check for success */
490         if (!NT_SUCCESS(Status))
491         {
492             /* We failed. Overwrite the return value with the failure. */
493             ApiMessage->Status = Status;
494         }
495     }
496 
497     /* Return the CSR Result */
498     DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
499     return ApiMessage->Status;
500 }
501 
502 /*
503  * @implemented
504  */
505 HANDLE
506 NTAPI
507 CsrGetProcessId(VOID)
508 {
509     return CsrProcessId;
510 }
511 
512 /* EOF */
513