1 /*
2 * PROJECT: ReactOS Client/Server Runtime SubSystem
3 * LICENSE: GPL-2.0-or-later (https://spdx.org/licenses/GPL-2.0-or-later)
4 * PURPOSE: CSR Client Library - CSR connection and calling
5 * COPYRIGHT: Copyright 2005-2013 Alex Ionescu <alex@relsoft.net>
6 * Copyright 2012-2022 Hermès Bélusca-Maïto <hermes.belusca-maito@reactos.org>
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include "csrlib.h"
12
13 #define NTOS_MODE_USER
14 #include <ndk/ldrfuncs.h>
15 #include <ndk/lpcfuncs.h>
16 #include <ndk/mmfuncs.h>
17 #include <ndk/obfuncs.h>
18 #include <ndk/umfuncs.h>
19
20 #include <csrsrv.h> // For CSR_CSRSS_SECTION_SIZE
21
22 #define NDEBUG
23 #include <debug.h>
24
25 /* GLOBALS ********************************************************************/
26
/* Handle of the connected CSR API LPC port; assigned by CsrpConnectToServer. */
HANDLE CsrApiPort;
/* Server process ID, taken from the connection reply (ServerProcessId). */
HANDLE CsrProcessId;
/*
 * Port heap: created over the client view of the shared port section in
 * CsrpConnectToServer, or set to the process heap when running inside the
 * CSR server process.
 */
HANDLE CsrPortHeap;
/*
 * Server view base minus client view base of the shared port section.
 * Adding it to a client pointer into that section yields the server address.
 */
ULONG_PTR CsrPortMemoryDelta;
/*
 * TRUE when this process is the CSR server itself (native-subsystem image);
 * CsrClientCallServer then calls CsrServerApiRoutine directly instead of LPC.
 */
BOOLEAN InsideCsrProcess = FALSE;

/* Signature of the in-process server dispatcher used for server-to-server calls. */
typedef NTSTATUS
(NTAPI *PCSR_SERVER_API_ROUTINE)(
    _In_ PCSR_API_MESSAGE Request,
    _Inout_ PCSR_API_MESSAGE Reply);

/* Resolved from csrsrv's "CsrCallServerFromServer" export by CsrClientConnectToServer. */
PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
39
40 /* FUNCTIONS ******************************************************************/
41
42 static NTSTATUS
CsrpConnectToServer(_In_ PCWSTR ObjectDirectory)43 CsrpConnectToServer(
44 _In_ PCWSTR ObjectDirectory)
45 {
46 NTSTATUS Status;
47 SIZE_T PortNameLength;
48 UNICODE_STRING PortName;
49 LARGE_INTEGER CsrSectionViewSize;
50 HANDLE CsrSectionHandle;
51 PORT_VIEW LpcWrite;
52 REMOTE_PORT_VIEW LpcRead;
53 SECURITY_QUALITY_OF_SERVICE SecurityQos;
54 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
55 PSID SystemSid = NULL;
56 CSR_API_CONNECTINFO ConnectionInfo;
57 ULONG ConnectionInfoLength = sizeof(ConnectionInfo);
58
59 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
60
61 /* Calculate the total port name size */
62 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
63 sizeof(CSR_PORT_NAME);
64 if (PortNameLength > UNICODE_STRING_MAX_BYTES)
65 {
66 DPRINT1("PortNameLength too big: %Iu\n", PortNameLength);
67 return STATUS_NAME_TOO_LONG;
68 }
69
70 /* Set the port name */
71 PortName.Length = 0;
72 PortName.MaximumLength = (USHORT)PortNameLength;
73
74 /* Allocate a buffer for it */
75 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
76 if (PortName.Buffer == NULL)
77 {
78 return STATUS_INSUFFICIENT_RESOURCES;
79 }
80
81 /* Create the name */
82 RtlAppendUnicodeToString(&PortName, ObjectDirectory);
83 RtlAppendUnicodeToString(&PortName, L"\\");
84 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
85
86 /* Create a section for the port memory */
87 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
88 Status = NtCreateSection(&CsrSectionHandle,
89 SECTION_ALL_ACCESS,
90 NULL,
91 &CsrSectionViewSize,
92 PAGE_READWRITE,
93 SEC_RESERVE,
94 NULL);
95 if (!NT_SUCCESS(Status))
96 {
97 DPRINT1("Failure allocating CSR Section\n");
98 return Status;
99 }
100
101 /* Set up the port view structures to match them with the section */
102 LpcWrite.Length = sizeof(LpcWrite);
103 LpcWrite.SectionHandle = CsrSectionHandle;
104 LpcWrite.SectionOffset = 0;
105 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
106 LpcWrite.ViewBase = 0;
107 LpcWrite.ViewRemoteBase = 0;
108 LpcRead.Length = sizeof(LpcRead);
109 LpcRead.ViewSize = 0;
110 LpcRead.ViewBase = 0;
111
112 /* Setup the QoS */
113 SecurityQos.ImpersonationLevel = SecurityImpersonation;
114 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
115 SecurityQos.EffectiveOnly = TRUE;
116
117 /* Setup the connection info */
118 ConnectionInfo.DebugFlags = 0;
119
120 /* Create a SID for us */
121 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
122 1,
123 SECURITY_LOCAL_SYSTEM_RID,
124 0,
125 0,
126 0,
127 0,
128 0,
129 0,
130 0,
131 &SystemSid);
132 if (!NT_SUCCESS(Status))
133 {
134 /* Failure */
135 DPRINT1("Couldn't allocate SID\n");
136 NtClose(CsrSectionHandle);
137 return Status;
138 }
139
140 /* Connect to the port */
141 Status = NtSecureConnectPort(&CsrApiPort,
142 &PortName,
143 &SecurityQos,
144 &LpcWrite,
145 SystemSid,
146 &LpcRead,
147 NULL,
148 &ConnectionInfo,
149 &ConnectionInfoLength);
150 RtlFreeSid(SystemSid);
151 NtClose(CsrSectionHandle);
152 if (!NT_SUCCESS(Status))
153 {
154 /* Failure */
155 DPRINT1("Couldn't connect to CSR port\n");
156 return Status;
157 }
158
159 /* Save the delta between the sections, for capture usage later */
160 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
161 (ULONG_PTR)LpcWrite.ViewBase;
162
163 /* Save the Process */
164 CsrProcessId = ConnectionInfo.ServerProcessId;
165
166 /* Save CSR Section data */
167 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
168 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
169 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
170
171 /* Create the port heap */
172 CsrPortHeap = RtlCreateHeap(0,
173 LpcWrite.ViewBase,
174 LpcWrite.ViewSize,
175 PAGE_SIZE,
176 0,
177 0);
178 if (CsrPortHeap == NULL)
179 {
180 /* Failure */
181 DPRINT1("Couldn't create heap for CSR port\n");
182 NtClose(CsrApiPort);
183 CsrApiPort = NULL;
184 return STATUS_INSUFFICIENT_RESOURCES;
185 }
186
187 /* Return success */
188 return STATUS_SUCCESS;
189 }
190
191 /*
192 * @implemented
193 */
194 NTSTATUS
195 NTAPI
CsrClientConnectToServer(_In_ PCWSTR ObjectDirectory,_In_ ULONG ServerId,_In_ PVOID ConnectionInfo,_Inout_ PULONG ConnectionInfoSize,_Out_ PBOOLEAN ServerToServerCall)196 CsrClientConnectToServer(
197 _In_ PCWSTR ObjectDirectory,
198 _In_ ULONG ServerId,
199 _In_ PVOID ConnectionInfo,
200 _Inout_ PULONG ConnectionInfoSize,
201 _Out_ PBOOLEAN ServerToServerCall)
202 {
203 NTSTATUS Status;
204 PIMAGE_NT_HEADERS NtHeader;
205
206 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
207
208 /* Validate the Connection Info */
209 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
210 {
211 DPRINT1("Connection info given, but no length\n");
212 return STATUS_INVALID_PARAMETER;
213 }
214
215 /* Check if we're inside a CSR Process */
216 if (InsideCsrProcess)
217 {
218 /* Tell the client that we're already inside CSR */
219 if (ServerToServerCall) *ServerToServerCall = TRUE;
220 return STATUS_SUCCESS;
221 }
222
223 /*
224 * We might be in a CSR Process but not know it, if this is the first call.
225 * So let's find out.
226 */
227 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
228 {
229 /* The image isn't valid */
230 DPRINT1("Invalid image\n");
231 return STATUS_INVALID_IMAGE_FORMAT;
232 }
233 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
234
235 /* Now we can check if we are inside or not */
236 if (InsideCsrProcess)
237 {
238 UNICODE_STRING CsrSrvName;
239 HANDLE hCsrSrv;
240 ANSI_STRING CsrServerRoutineName;
241
242 /* We're inside, so let's find csrsrv */
243 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
244 Status = LdrGetDllHandle(NULL,
245 NULL,
246 &CsrSrvName,
247 &hCsrSrv);
248
249 /* Now get the Server to Server routine */
250 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
251 Status = LdrGetProcedureAddress(hCsrSrv,
252 &CsrServerRoutineName,
253 0L,
254 (PVOID*)&CsrServerApiRoutine);
255
256 /* Use the local heap as port heap */
257 CsrPortHeap = RtlGetProcessHeap();
258
259 /* Tell the caller we're inside the server */
260 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
261 return STATUS_SUCCESS;
262 }
263
264 /* Now check if connection info is given */
265 if (ConnectionInfo)
266 {
267 CSR_API_MESSAGE ApiMessage;
268 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
269 PCSR_CAPTURE_BUFFER CaptureBuffer;
270
271 /* Well, we're definitely in a client now */
272 InsideCsrProcess = FALSE;
273
274 /* Do we have a connection to CSR yet? */
275 if (!CsrApiPort)
276 {
277 /* No, set it up now */
278 Status = CsrpConnectToServer(ObjectDirectory);
279 if (!NT_SUCCESS(Status))
280 {
281 /* Failed */
282 DPRINT1("Failure to connect to CSR\n");
283 return Status;
284 }
285 }
286
287 /* Setup the connect message header */
288 ClientConnect->ServerId = ServerId;
289 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
290
291 /* Setup a buffer for the connection info */
292 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
293 if (CaptureBuffer == NULL)
294 {
295 return STATUS_INSUFFICIENT_RESOURCES;
296 }
297
298 /* Capture the connection info data */
299 CsrCaptureMessageBuffer(CaptureBuffer,
300 ConnectionInfo,
301 ClientConnect->ConnectionInfoSize,
302 &ClientConnect->ConnectionInfo);
303
304 /* Return the allocated length */
305 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
306
307 /* Call CSR */
308 Status = CsrClientCallServer(&ApiMessage,
309 CaptureBuffer,
310 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
311 sizeof(*ClientConnect));
312
313 /* Copy the updated connection info data back into the user buffer */
314 RtlMoveMemory(ConnectionInfo,
315 ClientConnect->ConnectionInfo,
316 *ConnectionInfoSize);
317
318 /* Free the capture buffer */
319 CsrFreeCaptureBuffer(CaptureBuffer);
320 }
321 else
322 {
323 /* No connection info, just return */
324 Status = STATUS_SUCCESS;
325 }
326
327 /* Let the caller know if this was server to server */
328 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
329 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
330
331 return Status;
332 }
333
#if 0
//
// Compiled-out illustration: a structure can have trailing padding, so
// "sizeof(struct) - sizeof(last field)" is not necessarily equal to the
// offset of that last field. The header length computations in
// CsrClientCallServer account for intermediate padding (via FIELD_OFFSET)
// and trailing padding (via sizeof) separately for this reason.
//
// NOTE(review): the concrete sizes asserted below (e.g. sizeof(PORT_MESSAGE)
// == 0x18) are target/ABI-specific; confirm for the build target before
// enabling this block.
//
typedef struct _TEST_EMBEDDED
{
    ULONG One;
    ULONG Two;
    ULONG Three;
} TEST_EMBEDDED;

typedef struct _TEST
{
    PORT_MESSAGE h;
    TEST_EMBEDDED Three;
} TEST;

C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);

/* Trailing padding makes the whole larger than the sum of its parts... */
C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
/* ...so subtracting the last field's size does not give its offset */
C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
#endif
360
/*
 * @implemented
 *
 * Sends a CSR API request and waits for the reply. The reply is written in
 * place into ApiMessage.
 *
 * ApiMessage    - Message to send. Its port header, ApiNumber and
 *                 CsrCaptureData fields are filled in here.
 * CaptureBuffer - Optional capture buffer whose PointerOffsetsArray records
 *                 the addresses of message fields that point into it. Those
 *                 pointers are translated to the server view before sending
 *                 and back to the client view after the reply (LPC path only).
 * ApiNumber     - API number, e.g. built with CSR_CREATE_API_NUMBER.
 * DataLength    - Size in bytes of the API-specific data in ApiMessage->Data.
 *
 * Returns STATUS_INVALID_PARAMETER if DataLength is too large. Otherwise
 * returns ApiMessage->Status, overwritten with the LPC / direct-call status
 * when that status is a failure.
 */
NTSTATUS
NTAPI
CsrClientCallServer(
    _Inout_ PCSR_API_MESSAGE ApiMessage,
    _Inout_opt_ PCSR_CAPTURE_BUFFER CaptureBuffer,
    _In_ CSR_API_NUMBER ApiNumber,
    _In_ ULONG DataLength)
{
    NTSTATUS Status;

    /* Make sure the length is valid (the header lengths below are CSHORTs) */
    if (DataLength > (MAXSHORT - sizeof(CSR_API_MESSAGE)))
    {
        DPRINT1("DataLength too big: %lu\n", DataLength);
        return STATUS_INVALID_PARAMETER;
    }

    /* Fill out the Port Message Header */
    ApiMessage->Header.u2.ZeroInit = 0;
    /* DataLength = user_data_size + anything between
     * header and data, including intermediate padding */
    ApiMessage->Header.u1.s1.DataLength = (CSHORT)DataLength +
        FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header);
    /* TotalLength = header_size + DataLength + any structure trailing padding */
    ApiMessage->Header.u1.s1.TotalLength = (CSHORT)DataLength +
        sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data);

    /* Fill out the CSR Header */
    ApiMessage->ApiNumber = ApiNumber;
    ApiMessage->CsrCaptureData = NULL;

    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
           ApiNumber,
           ApiMessage->Header.u1.s1.DataLength,
           ApiMessage->Header.u1.s1.TotalLength);

    /* Check if we are already inside a CSR Server */
    if (!InsideCsrProcess)
    {
        ULONG PointerCount;
        PULONG_PTR OffsetPointer;

        /* Check if we got a Capture Buffer */
        if (CaptureBuffer)
        {
            /*
             * We have to convert from our local (client) view
             * to the remote (server) view.
             */
            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
                ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);

            /* Lock the buffer */
            CaptureBuffer->BufferEnd = NULL;

            /*
             * Each client pointer inside the CSR message is converted into
             * a server pointer, and each pointer to these message pointers
             * is converted into an offset.
             *
             * Order matters: the entry still holds the address of the message
             * field, so the field is rebased first; only then is the entry
             * turned into an offset from ApiMessage (presumably because the
             * server sees the message at a different address).
             * Zero entries are skipped.
             */
            PointerCount = CaptureBuffer->PointerCount;
            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
            while (PointerCount--)
            {
                if (*OffsetPointer != 0)
                {
                    *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
                    *OffsetPointer -= (ULONG_PTR)ApiMessage;
                }
                ++OffsetPointer;
            }
        }

        /* Send the LPC Message; the reply overwrites ApiMessage in place */
        Status = NtRequestWaitReplyPort(CsrApiPort,
                                        &ApiMessage->Header,
                                        &ApiMessage->Header);

        /* Check if we got a Capture Buffer */
        if (CaptureBuffer)
        {
            /*
             * We have to convert back from the remote (server) view
             * to our local (client) view.
             */
            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
                ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);

            /*
             * Convert back the offsets into pointers to CSR message
             * pointers, and convert back these message server pointers
             * into client pointers.
             *
             * This is the exact reverse of the loop above: first restore the
             * entry to an address, then rebase the field it points to.
             */
            PointerCount = CaptureBuffer->PointerCount;
            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
            while (PointerCount--)
            {
                if (*OffsetPointer != 0)
                {
                    *OffsetPointer += (ULONG_PTR)ApiMessage;
                    *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
                }
                ++OffsetPointer;
            }
        }

        /* Check for success */
        if (!NT_SUCCESS(Status))
        {
            /* We failed. Overwrite the return value with the failure. */
            DPRINT1("LPC Failed: %lx\n", Status);
            ApiMessage->Status = Status;
        }
    }
    else
    {
        /* This is a server-to-server call */
        DPRINT("Server-to-server call\n");

        /* Save our CID; we check this equality inside CsrValidateMessageBuffer */
        ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;

        /*
         * Do a direct call. No pointer translation is done on this path;
         * NOTE(review): this assumes CsrServerApiRoutine was resolved by
         * CsrClientConnectToServer.
         */
        Status = CsrServerApiRoutine(ApiMessage, ApiMessage);

        /* Check for success */
        if (!NT_SUCCESS(Status))
        {
            /* We failed. Overwrite the return value with the failure. */
            ApiMessage->Status = Status;
        }
    }

    /* Return the CSR Result */
    DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
    return ApiMessage->Status;
}
501
502 /*
503 * @implemented
504 */
505 HANDLE
506 NTAPI
CsrGetProcessId(VOID)507 CsrGetProcessId(VOID)
508 {
509 return CsrProcessId;
510 }
511
512 /* EOF */
513