xref: /reactos/dll/win32/ole32/rpc.c (revision 0f5d91b7)
1 /*
2  *	RPC Manager
3  *
4  * Copyright 2001  Ove Kåven, TransGaming Technologies
5  * Copyright 2002  Marcus Meissner
6  * Copyright 2005  Mike Hearn, Rob Shearman for CodeWeavers
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22 
23 #include <stdarg.h>
24 #include <string.h>
25 
26 #define COBJMACROS
27 #define NONAMELESSUNION
28 
29 #include "windef.h"
30 #include "winbase.h"
31 #include "winuser.h"
32 #include "winsvc.h"
33 #include "objbase.h"
34 #include "ole2.h"
35 #include "rpc.h"
36 #include "winerror.h"
37 #include "winreg.h"
38 #include "servprov.h"
39 
40 #include "compobj_private.h"
41 
42 #include "wine/debug.h"
43 
44 WINE_DEFAULT_DEBUG_CHANNEL(ole);
45 
46 static void __RPC_STUB dispatch_rpc(RPC_MESSAGE *msg);
47 
48 /* we only use one function to dispatch calls for all methods - we use the
49  * RPC_IF_OLE flag to tell the RPC runtime that this is the case */
50 static RPC_DISPATCH_FUNCTION rpc_dispatch_table[1] = { dispatch_rpc }; /* (RO) */
51 static RPC_DISPATCH_TABLE rpc_dispatch = { 1, rpc_dispatch_table }; /* (RO) */
52 
53 static struct list registered_interfaces = LIST_INIT(registered_interfaces); /* (CS csRegIf) */
54 static CRITICAL_SECTION csRegIf;
55 static CRITICAL_SECTION_DEBUG csRegIf_debug =
56 {
57     0, 0, &csRegIf,
58     { &csRegIf_debug.ProcessLocksList, &csRegIf_debug.ProcessLocksList },
59       0, 0, { (DWORD_PTR)(__FILE__ ": dcom registered server interfaces") }
60 };
61 static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 };
62 
63 static struct list channel_hooks = LIST_INIT(channel_hooks); /* (CS csChannelHook) */
64 static CRITICAL_SECTION csChannelHook;
65 static CRITICAL_SECTION_DEBUG csChannelHook_debug =
66 {
67     0, 0, &csChannelHook,
68     { &csChannelHook_debug.ProcessLocksList, &csChannelHook_debug.ProcessLocksList },
69       0, 0, { (DWORD_PTR)(__FILE__ ": channel hooks") }
70 };
71 static CRITICAL_SECTION csChannelHook = { &csChannelHook_debug, -1, 0, 0, 0, 0 };
72 
73 static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0};
74 
75 
76 struct registered_if
77 {
78     struct list entry;
79     DWORD refs; /* ref count */
80     RPC_SERVER_INTERFACE If; /* interface registered with the RPC runtime */
81 };
82 
83 /* get the pipe endpoint specified of the specified apartment */
get_rpc_endpoint(LPWSTR endpoint,const OXID * oxid)84 static inline void get_rpc_endpoint(LPWSTR endpoint, const OXID *oxid)
85 {
86     /* FIXME: should get endpoint from rpcss */
87     static const WCHAR wszEndpointFormat[] = {'\\','p','i','p','e','\\','O','L','E','_','%','0','8','l','x','%','0','8','l','x',0};
88     wsprintfW(endpoint, wszEndpointFormat, (DWORD)(*oxid >> 32),(DWORD)*oxid);
89 }
90 
91 typedef struct
92 {
93     IRpcChannelBuffer IRpcChannelBuffer_iface;
94     LONG refs;
95 
96     DWORD dest_context; /* returned from GetDestCtx */
97     void *dest_context_data; /* returned from GetDestCtx */
98 } RpcChannelBuffer;
99 
100 typedef struct
101 {
102     RpcChannelBuffer       super; /* superclass */
103 
104     RPC_BINDING_HANDLE     bind; /* handle to the remote server */
105     OXID                   oxid; /* apartment in which the channel is valid */
106     DWORD                  server_pid; /* id of server process */
107     HANDLE                 event; /* cached event handle */
108     IID                    iid; /* IID of the proxy this belongs to */
109 } ClientRpcChannelBuffer;
110 
111 struct dispatch_params
112 {
113     RPCOLEMESSAGE     *msg; /* message */
114     IRpcStubBuffer    *stub; /* stub buffer, if applicable */
115     IRpcChannelBuffer *chan; /* server channel buffer, if applicable */
116     IID                iid; /* ID of interface being called */
117     IUnknown          *iface; /* interface being called */
118     HANDLE             handle; /* handle that will become signaled when call finishes */
119     BOOL               bypass_rpcrt; /* bypass RPC runtime? */
120     RPC_STATUS         status; /* status (out) */
121     HRESULT            hr; /* hresult (out) */
122 };
123 
124 struct message_state
125 {
126     RPC_BINDING_HANDLE binding_handle;
127     ULONG prefix_data_len;
128     SChannelHookCallInfo channel_hook_info;
129     BOOL bypass_rpcrt;
130 
131     /* client only */
132     HWND target_hwnd;
133     DWORD target_tid;
134     struct dispatch_params params;
135 };
136 
137 typedef struct
138 {
139     ULONG conformance; /* NDR */
140     GUID id;
141     ULONG size;
142     /* [size_is((size+7)&~7)] */ unsigned char data[1];
143 } WIRE_ORPC_EXTENT;
144 
145 typedef struct
146 {
147     ULONG size;
148     ULONG reserved;
149     unsigned char extent[1];
150 } WIRE_ORPC_EXTENT_ARRAY;
151 
152 typedef struct
153 {
154     ULONG version;
155     ULONG flags;
156     ULONG reserved1;
157     GUID  cid;
158     unsigned char extensions[1];
159 } WIRE_ORPCTHIS;
160 
161 typedef struct
162 {
163     ULONG flags;
164     unsigned char extensions[1];
165 } WIRE_ORPCTHAT;
166 
167 struct channel_hook_entry
168 {
169     struct list entry;
170     GUID id;
171     IChannelHook *hook;
172 };
173 
174 struct channel_hook_buffer_data
175 {
176     GUID id;
177     ULONG extension_size;
178 };
179 
180 
181 static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
182                                   ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent);
183 
184 /* Channel Hook Functions */
185 
ChannelHooks_ClientGetSize(SChannelHookCallInfo * info,struct channel_hook_buffer_data ** data,unsigned int * hook_count,ULONG * extension_count)186 static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
187     struct channel_hook_buffer_data **data, unsigned int *hook_count,
188     ULONG *extension_count)
189 {
190     struct channel_hook_entry *entry;
191     ULONG total_size = 0;
192     unsigned int hook_index = 0;
193 
194     *hook_count = 0;
195     *extension_count = 0;
196 
197     EnterCriticalSection(&csChannelHook);
198 
199     LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
200         (*hook_count)++;
201 
202     if (*hook_count)
203         *data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
204     else
205         *data = NULL;
206 
207     LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
208     {
209         ULONG extension_size = 0;
210 
211         IChannelHook_ClientGetSize(entry->hook, &entry->id, &info->iid, &extension_size);
212 
213         TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
214 
215         extension_size = (extension_size+7)&~7;
216         (*data)[hook_index].id = entry->id;
217         (*data)[hook_index].extension_size = extension_size;
218 
219         /* an extension is only put onto the wire if it has data to write */
220         if (extension_size)
221         {
222             total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
223             (*extension_count)++;
224         }
225 
226         hook_index++;
227     }
228 
229     LeaveCriticalSection(&csChannelHook);
230 
231     return total_size;
232 }
233 
/* Let every channel hook write its data for an outgoing call into buffer,
 * as consecutive WIRE_ORPC_EXTENT records in channel_hooks order.
 *
 * data/hook_count are the sizes gathered by the GetSize pass; a hook with no
 * reserved space there (including one registered since) is skipped. A hook
 * is never allowed to consume more than the space reserved for it, and the
 * padding after its data is zeroed so no stale heap bytes are transmitted.
 * Returns the position just past the last record written. */
static unsigned char * ChannelHooks_ClientFillBuffer(SChannelHookCallInfo *info,
    unsigned char *buffer, struct channel_hook_buffer_data *data,
    unsigned int hook_count)
{
    struct channel_hook_entry *entry;

    EnterCriticalSection(&csChannelHook);

    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
    {
        unsigned int i;
        ULONG reserved_size = 0;
        ULONG extension_size;
        WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;

        for (i = 0; i < hook_count; i++)
            if (IsEqualGUID(&entry->id, &data[i].id))
                reserved_size = data[i].extension_size;

        /* an extension is only put onto the wire if it has data to write */
        if (!reserved_size)
            continue;

        /* in: space available; out: bytes the hook actually wrote */
        extension_size = reserved_size;
        IChannelHook_ClientFillBuffer(entry->hook, &entry->id, &info->iid,
            &extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]));

        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);

        /* the buffer was sized from the GetSize pass, so stepping past the
         * reserved space would run off the end of it */
        if (extension_size > reserved_size)
        {
            WARN("%s: hook reported %u bytes but only %u were reserved\n",
                 debugstr_guid(&entry->id), extension_size, reserved_size);
            extension_size = reserved_size;
        }

        wire_orpc_extent->conformance = (extension_size+7)&~7;
        wire_orpc_extent->size = extension_size;
        wire_orpc_extent->id = entry->id;
        /* zero the alignment padding rather than sending uninitialised memory */
        memset(wire_orpc_extent->data + extension_size, 0,
               wire_orpc_extent->conformance - extension_size);
        buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
    }

    LeaveCriticalSection(&csChannelHook);

    return buffer;
}
273 
/* Hand the extension data that arrived with an incoming call to every
 * registered channel hook. first_wire_orpc_extent points at extension_count
 * consecutive extents; each hook receives the extent whose GUID matches its
 * own, or a size of 0 and a NULL buffer when there is none.
 *
 * NOTE(review): extents are walked using the wire-supplied conformance with
 * no bounds check here; presumably validated by the unmarshalling code —
 * confirm. */
static void ChannelHooks_ServerNotify(SChannelHookCallInfo *info,
    DWORD lDataRep, WIRE_ORPC_EXTENT *first_wire_orpc_extent,
    ULONG extension_count)
{
    struct channel_hook_entry *entry;

    EnterCriticalSection(&csChannelHook);

    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
    {
        WIRE_ORPC_EXTENT *match = NULL;
        WIRE_ORPC_EXTENT *cursor = first_wire_orpc_extent;
        ULONG remaining;

        /* extents are variable length: skip each one's padded data */
        for (remaining = extension_count; remaining > 0; remaining--)
        {
            if (IsEqualGUID(&entry->id, &cursor->id))
            {
                match = cursor;
                break;
            }
            cursor = (WIRE_ORPC_EXTENT *)&cursor->data[cursor->conformance];
        }

        if (match)
            IChannelHook_ServerNotify(entry->hook, &entry->id, &info->iid,
                                      match->size, match->data, lDataRep);
        else
            IChannelHook_ServerNotify(entry->hook, &entry->id, &info->iid,
                                      0, NULL, lDataRep);
    }

    LeaveCriticalSection(&csChannelHook);
}
303 
/* Ask every registered channel hook how much data it wants to attach to the
 * reply of an incoming call.
 *
 * On return *data is a HeapAlloc'ed array of *hook_count entries (NULL when
 * there are no hooks) holding each hook's GUID and its size rounded up to a
 * multiple of 8; the caller frees it. *extension_count is the number of
 * hooks that asked for a non-zero size. The return value is the total wire
 * size of those extents, headers included.
 *
 * If the bookkeeping array cannot be allocated, no hook data is sent:
 * *data is NULL, both counts are 0 and 0 is returned. */
static ULONG ChannelHooks_ServerGetSize(SChannelHookCallInfo *info,
                                        struct channel_hook_buffer_data **data, unsigned int *hook_count,
                                        ULONG *extension_count)
{
    struct channel_hook_entry *entry;
    ULONG total_size = 0;
    unsigned int hook_index = 0;

    *data = NULL;
    *hook_count = 0;
    *extension_count = 0;

    EnterCriticalSection(&csChannelHook);

    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
        (*hook_count)++;

    if (*hook_count)
    {
        *data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
        if (!*data)
        {
            /* without somewhere to record the per-hook sizes we cannot
             * reserve space for them; send no extensions rather than
             * dereferencing NULL below */
            ERR("out of memory, not sending channel hook data\n");
            *hook_count = 0;
            LeaveCriticalSection(&csChannelHook);
            return 0;
        }
    }

    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
    {
        ULONG extension_size = 0;

        IChannelHook_ServerGetSize(entry->hook, &entry->id, &info->iid, S_OK,
                                   &extension_size);

        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);

        /* extent data is padded to 8 bytes on the wire */
        extension_size = (extension_size+7)&~7;
        (*data)[hook_index].id = entry->id;
        (*data)[hook_index].extension_size = extension_size;

        /* an extension is only put onto the wire if it has data to write */
        if (extension_size)
        {
            total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
            (*extension_count)++;
        }

        hook_index++;
    }

    LeaveCriticalSection(&csChannelHook);

    return total_size;
}
352 
/* Let every channel hook write its reply data into buffer, as consecutive
 * WIRE_ORPC_EXTENT records in channel_hooks order.
 *
 * data/hook_count are the sizes gathered by the GetSize pass; a hook with no
 * reserved space there (including one registered since) is skipped. A hook
 * is never allowed to consume more than the space reserved for it, and the
 * padding after its data is zeroed so no stale heap bytes are transmitted.
 * Returns the position just past the last record written. */
static unsigned char * ChannelHooks_ServerFillBuffer(SChannelHookCallInfo *info,
                                                     unsigned char *buffer, struct channel_hook_buffer_data *data,
                                                     unsigned int hook_count)
{
    struct channel_hook_entry *entry;

    EnterCriticalSection(&csChannelHook);

    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
    {
        unsigned int i;
        ULONG reserved_size = 0;
        ULONG extension_size;
        WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;

        for (i = 0; i < hook_count; i++)
            if (IsEqualGUID(&entry->id, &data[i].id))
                reserved_size = data[i].extension_size;

        /* an extension is only put onto the wire if it has data to write */
        if (!reserved_size)
            continue;

        /* in: space available; out: bytes the hook actually wrote */
        extension_size = reserved_size;
        IChannelHook_ServerFillBuffer(entry->hook, &entry->id, &info->iid,
                                      &extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]),
                                      S_OK);

        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);

        /* the buffer was sized from the GetSize pass, so stepping past the
         * reserved space would run off the end of it */
        if (extension_size > reserved_size)
        {
            WARN("%s: hook reported %u bytes but only %u were reserved\n",
                 debugstr_guid(&entry->id), extension_size, reserved_size);
            extension_size = reserved_size;
        }

        wire_orpc_extent->conformance = (extension_size+7)&~7;
        wire_orpc_extent->size = extension_size;
        wire_orpc_extent->id = entry->id;
        /* zero the alignment padding rather than sending uninitialised memory */
        memset(wire_orpc_extent->data + extension_size, 0,
               wire_orpc_extent->conformance - extension_size);
        buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
    }

    LeaveCriticalSection(&csChannelHook);

    return buffer;
}
393 
/* Hand the extension data that arrived with a reply to every registered
 * channel hook, together with the call's fault code hrFault.
 * first_wire_orpc_extent points at extension_count consecutive extents; each
 * hook receives the extent whose GUID matches its own, or a size of 0 and a
 * NULL buffer when there is none.
 *
 * NOTE(review): extents are walked using the wire-supplied conformance with
 * no bounds check here; presumably validated by unmarshal_ORPCTHAT —
 * confirm. */
static void ChannelHooks_ClientNotify(SChannelHookCallInfo *info,
                                      DWORD lDataRep, WIRE_ORPC_EXTENT *first_wire_orpc_extent,
                                      ULONG extension_count, HRESULT hrFault)
{
    struct channel_hook_entry *entry;

    EnterCriticalSection(&csChannelHook);

    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
    {
        WIRE_ORPC_EXTENT *match = NULL;
        WIRE_ORPC_EXTENT *cursor = first_wire_orpc_extent;
        ULONG remaining;

        /* extents are variable length: skip each one's padded data */
        for (remaining = extension_count; remaining > 0; remaining--)
        {
            if (IsEqualGUID(&entry->id, &cursor->id))
            {
                match = cursor;
                break;
            }
            cursor = (WIRE_ORPC_EXTENT *)&cursor->data[cursor->conformance];
        }

        if (match)
            IChannelHook_ClientNotify(entry->hook, &entry->id, &info->iid,
                                      match->size, match->data,
                                      lDataRep, hrFault);
        else
            IChannelHook_ClientNotify(entry->hook, &entry->id, &info->iid,
                                      0, NULL,
                                      lDataRep, hrFault);
    }

    LeaveCriticalSection(&csChannelHook);
}
423 
RPC_RegisterChannelHook(REFGUID rguid,IChannelHook * hook)424 HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
425 {
426     struct channel_hook_entry *entry;
427 
428     TRACE("(%s, %p)\n", debugstr_guid(rguid), hook);
429 
430     entry = HeapAlloc(GetProcessHeap(), 0, sizeof(*entry));
431     if (!entry)
432         return E_OUTOFMEMORY;
433 
434     entry->id = *rguid;
435     entry->hook = hook;
436     IChannelHook_AddRef(hook);
437 
438     EnterCriticalSection(&csChannelHook);
439     list_add_tail(&channel_hooks, &entry->entry);
440     LeaveCriticalSection(&csChannelHook);
441 
442     return S_OK;
443 }
444 
RPC_UnregisterAllChannelHooks(void)445 void RPC_UnregisterAllChannelHooks(void)
446 {
447     struct channel_hook_entry *cursor;
448     struct channel_hook_entry *cursor2;
449 
450     EnterCriticalSection(&csChannelHook);
451     LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &channel_hooks, struct channel_hook_entry, entry)
452         HeapFree(GetProcessHeap(), 0, cursor);
453     LeaveCriticalSection(&csChannelHook);
454     DeleteCriticalSection(&csChannelHook);
455     DeleteCriticalSection(&csRegIf);
456 }
457 
458 /* RPC Channel Buffer Functions */
459 
RpcChannelBuffer_QueryInterface(IRpcChannelBuffer * iface,REFIID riid,LPVOID * ppv)460 static HRESULT WINAPI RpcChannelBuffer_QueryInterface(IRpcChannelBuffer *iface, REFIID riid, LPVOID *ppv)
461 {
462     *ppv = NULL;
463     if (IsEqualIID(riid,&IID_IRpcChannelBuffer) || IsEqualIID(riid,&IID_IUnknown))
464     {
465         *ppv = iface;
466         IRpcChannelBuffer_AddRef(iface);
467         return S_OK;
468     }
469     return E_NOINTERFACE;
470 }
471 
RpcChannelBuffer_AddRef(LPRPCCHANNELBUFFER iface)472 static ULONG WINAPI RpcChannelBuffer_AddRef(LPRPCCHANNELBUFFER iface)
473 {
474     RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
475     return InterlockedIncrement(&This->refs);
476 }
477 
ServerRpcChannelBuffer_Release(LPRPCCHANNELBUFFER iface)478 static ULONG WINAPI ServerRpcChannelBuffer_Release(LPRPCCHANNELBUFFER iface)
479 {
480     RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
481     ULONG ref;
482 
483     ref = InterlockedDecrement(&This->refs);
484     if (ref)
485         return ref;
486 
487     HeapFree(GetProcessHeap(), 0, This);
488     return 0;
489 }
490 
ClientRpcChannelBuffer_Release(LPRPCCHANNELBUFFER iface)491 static ULONG WINAPI ClientRpcChannelBuffer_Release(LPRPCCHANNELBUFFER iface)
492 {
493     ClientRpcChannelBuffer *This = (ClientRpcChannelBuffer *)iface;
494     ULONG ref;
495 
496     ref = InterlockedDecrement(&This->super.refs);
497     if (ref)
498         return ref;
499 
500     if (This->event) CloseHandle(This->event);
501     RpcBindingFree(&This->bind);
502     HeapFree(GetProcessHeap(), 0, This);
503     return 0;
504 }
505 
ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,RPCOLEMESSAGE * olemsg,REFIID riid)506 static HRESULT WINAPI ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, RPCOLEMESSAGE* olemsg, REFIID riid)
507 {
508     RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
509     RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
510     RPC_STATUS status;
511     ORPCTHAT *orpcthat;
512     struct message_state *message_state;
513     ULONG extensions_size;
514     struct channel_hook_buffer_data *channel_hook_data;
515     unsigned int channel_hook_count;
516     ULONG extension_count;
517 
518     TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
519 
520     message_state = msg->Handle;
521     /* restore the binding handle and the real start of data */
522     msg->Handle = message_state->binding_handle;
523     msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
524 
525     extensions_size = ChannelHooks_ServerGetSize(&message_state->channel_hook_info,
526                                                  &channel_hook_data, &channel_hook_count, &extension_count);
527 
528     msg->BufferLength += FIELD_OFFSET(WIRE_ORPCTHAT, extensions) + sizeof(DWORD);
529     if (extensions_size)
530     {
531         msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT_ARRAY, extent[2*sizeof(DWORD) + extensions_size]);
532         if (extension_count & 1)
533             msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
534     }
535 
536     if (message_state->bypass_rpcrt)
537     {
538         msg->Buffer = HeapAlloc(GetProcessHeap(), 0, msg->BufferLength);
539         if (msg->Buffer)
540             status = RPC_S_OK;
541         else
542         {
543             HeapFree(GetProcessHeap(), 0, channel_hook_data);
544             return E_OUTOFMEMORY;
545         }
546     }
547     else
548         status = I_RpcGetBuffer(msg);
549 
550     orpcthat = msg->Buffer;
551     msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPCTHAT, extensions);
552 
553     orpcthat->flags = ORPCF_NULL /* FIXME? */;
554 
555     /* NDR representation of orpcthat->extensions */
556     *(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
557     msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
558 
559     if (extensions_size)
560     {
561         WIRE_ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
562         orpc_extent_array->size = extension_count;
563         orpc_extent_array->reserved = 0;
564         msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT_ARRAY, extent);
565         /* NDR representation of orpc_extent_array->extent */
566         *(DWORD *)msg->Buffer = 1;
567         msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
568         /* NDR representation of [size_is] attribute of orpc_extent_array->extent */
569         *(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
570         msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
571 
572         msg->Buffer = ChannelHooks_ServerFillBuffer(&message_state->channel_hook_info,
573                                                     msg->Buffer, channel_hook_data, channel_hook_count);
574 
575         /* we must add a dummy extension if there is an odd extension
576          * count to meet the contract specified by the size_is attribute */
577         if (extension_count & 1)
578         {
579             WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
580             wire_orpc_extent->conformance = 0;
581             wire_orpc_extent->id = GUID_NULL;
582             wire_orpc_extent->size = 0;
583             msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
584         }
585     }
586 
587     HeapFree(GetProcessHeap(), 0, channel_hook_data);
588 
589     /* store the prefixed data length so that we can restore the real buffer
590      * later */
591     message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthat;
592     msg->BufferLength -= message_state->prefix_data_len;
593     /* save away the message state again */
594     msg->Handle = message_state;
595 
596     TRACE("-- %d\n", status);
597 
598     return HRESULT_FROM_WIN32(status);
599 }
600 
ClientRpcChannelBuffer_GetEventHandle(ClientRpcChannelBuffer * This)601 static HANDLE ClientRpcChannelBuffer_GetEventHandle(ClientRpcChannelBuffer *This)
602 {
603     HANDLE event = InterlockedExchangePointer(&This->event, NULL);
604 
605     /* Note: must be auto-reset event so we can reuse it without a call
606     * to ResetEvent */
607     if (!event) event = CreateEventW(NULL, FALSE, FALSE, NULL);
608 
609     return event;
610 }
611 
/* Return an event obtained from ClientRpcChannelBuffer_GetEventHandle: it is
 * cached on the channel if the slot is empty, otherwise closed. */
static void ClientRpcChannelBuffer_ReleaseEventHandle(ClientRpcChannelBuffer *This, HANDLE event)
{
    HANDLE previous = InterlockedCompareExchangePointer(&This->event, event, NULL);

    /* another event was already cached, so this one is surplus */
    if (previous != NULL)
        CloseHandle(event);
}
618 
ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,RPCOLEMESSAGE * olemsg,REFIID riid)619 static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, RPCOLEMESSAGE* olemsg, REFIID riid)
620 {
621     ClientRpcChannelBuffer *This = (ClientRpcChannelBuffer *)iface;
622     RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
623     RPC_CLIENT_INTERFACE *cif;
624     RPC_STATUS status;
625     ORPCTHIS *orpcthis;
626     struct message_state *message_state;
627     ULONG extensions_size;
628     struct channel_hook_buffer_data *channel_hook_data;
629     unsigned int channel_hook_count;
630     ULONG extension_count;
631     IPID ipid;
632     HRESULT hr;
633     APARTMENT *apt = NULL;
634 
635     TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
636 
637     cif = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RPC_CLIENT_INTERFACE));
638     if (!cif)
639         return E_OUTOFMEMORY;
640 
641     message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state));
642     if (!message_state)
643     {
644         HeapFree(GetProcessHeap(), 0, cif);
645         return E_OUTOFMEMORY;
646     }
647 
648     cif->Length = sizeof(RPC_CLIENT_INTERFACE);
649     /* RPC interface ID = COM interface ID */
650     cif->InterfaceId.SyntaxGUID = This->iid;
651     /* COM objects always have a version of 0.0 */
652     cif->InterfaceId.SyntaxVersion.MajorVersion = 0;
653     cif->InterfaceId.SyntaxVersion.MinorVersion = 0;
654     msg->Handle = This->bind;
655     msg->RpcInterfaceInformation = cif;
656 
657     message_state->prefix_data_len = 0;
658     message_state->binding_handle = This->bind;
659 
660     message_state->channel_hook_info.iid = *riid;
661     message_state->channel_hook_info.cbSize = sizeof(message_state->channel_hook_info);
662     message_state->channel_hook_info.uCausality = COM_CurrentCausalityId();
663     message_state->channel_hook_info.dwServerPid = This->server_pid;
664     message_state->channel_hook_info.iMethod = msg->ProcNum & ~RPC_FLAGS_VALID_BIT;
665     message_state->channel_hook_info.pObject = NULL; /* only present on server-side */
666     message_state->target_hwnd = NULL;
667     message_state->target_tid = 0;
668     memset(&message_state->params, 0, sizeof(message_state->params));
669 
670     extensions_size = ChannelHooks_ClientGetSize(&message_state->channel_hook_info,
671         &channel_hook_data, &channel_hook_count, &extension_count);
672 
673     msg->BufferLength += FIELD_OFFSET(WIRE_ORPCTHIS, extensions) + sizeof(DWORD);
674     if (extensions_size)
675     {
676         msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT_ARRAY, extent[2*sizeof(DWORD) + extensions_size]);
677         if (extension_count & 1)
678             msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
679     }
680 
681     RpcBindingInqObject(message_state->binding_handle, &ipid);
682     hr = ipid_get_dispatch_params(&ipid, &apt, NULL, &message_state->params.stub,
683                                   &message_state->params.chan,
684                                   &message_state->params.iid,
685                                   &message_state->params.iface);
686     if (hr == S_OK)
687     {
688         /* stub, chan, iface and iid are unneeded in multi-threaded case as we go
689          * via the RPC runtime */
690         if (apt->multi_threaded)
691         {
692             IRpcStubBuffer_Release(message_state->params.stub);
693             message_state->params.stub = NULL;
694             IRpcChannelBuffer_Release(message_state->params.chan);
695             message_state->params.chan = NULL;
696             message_state->params.iface = NULL;
697         }
698         else
699         {
700             message_state->params.bypass_rpcrt = TRUE;
701             message_state->target_hwnd = apartment_getwindow(apt);
702             message_state->target_tid = apt->tid;
703             /* we assume later on that this being non-NULL is the indicator that
704              * means call directly instead of going through RPC runtime */
705             if (!message_state->target_hwnd)
706                 ERR("window for apartment %s is NULL\n", wine_dbgstr_longlong(apt->oxid));
707         }
708     }
709     if (apt) apartment_release(apt);
710     message_state->params.handle = ClientRpcChannelBuffer_GetEventHandle(This);
711     /* Note: message_state->params.msg is initialised in
712      * ClientRpcChannelBuffer_SendReceive */
713 
714     /* shortcut the RPC runtime */
715     if (message_state->target_hwnd)
716     {
717         msg->Buffer = HeapAlloc(GetProcessHeap(), 0, msg->BufferLength);
718         if (msg->Buffer)
719             status = RPC_S_OK;
720         else
721             status = ERROR_OUTOFMEMORY;
722     }
723     else
724         status = I_RpcGetBuffer(msg);
725 
726     msg->Handle = message_state;
727 
728     if (status == RPC_S_OK)
729     {
730         orpcthis = msg->Buffer;
731         msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPCTHIS, extensions);
732 
733         orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
734         orpcthis->version.MinorVersion = COM_MINOR_VERSION;
735         orpcthis->flags = message_state->channel_hook_info.dwServerPid ? ORPCF_LOCAL : ORPCF_NULL;
736         orpcthis->reserved1 = 0;
737         orpcthis->cid = message_state->channel_hook_info.uCausality;
738 
739         /* NDR representation of orpcthis->extensions */
740         *(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
741         msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
742 
743         if (extensions_size)
744         {
745             ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
746             orpc_extent_array->size = extension_count;
747             orpc_extent_array->reserved = 0;
748             msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT_ARRAY, extent);
749             /* NDR representation of orpc_extent_array->extent */
750             *(DWORD *)msg->Buffer = 1;
751             msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
752             /* NDR representation of [size_is] attribute of orpc_extent_array->extent */
753             *(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
754             msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
755 
756             msg->Buffer = ChannelHooks_ClientFillBuffer(&message_state->channel_hook_info,
757                 msg->Buffer, channel_hook_data, channel_hook_count);
758 
759             /* we must add a dummy extension if there is an odd extension
760              * count to meet the contract specified by the size_is attribute */
761             if (extension_count & 1)
762             {
763                 WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
764                 wire_orpc_extent->conformance = 0;
765                 wire_orpc_extent->id = GUID_NULL;
766                 wire_orpc_extent->size = 0;
767                 msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
768             }
769         }
770 
771         /* store the prefixed data length so that we can restore the real buffer
772          * pointer in ClientRpcChannelBuffer_SendReceive. */
773         message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
774         msg->BufferLength -= message_state->prefix_data_len;
775     }
776 
777     HeapFree(GetProcessHeap(), 0, channel_hook_data);
778 
779     TRACE("-- %d\n", status);
780 
781     return HRESULT_FROM_WIN32(status);
782 }
783 
ServerRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER iface,RPCOLEMESSAGE * olemsg,ULONG * pstatus)784 static HRESULT WINAPI ServerRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER iface, RPCOLEMESSAGE *olemsg, ULONG *pstatus)
785 {
786     FIXME("stub\n");
787     return E_NOTIMPL;
788 }
789 
790 /* this thread runs an outgoing RPC */
rpc_sendreceive_thread(LPVOID param)791 static DWORD WINAPI rpc_sendreceive_thread(LPVOID param)
792 {
793     struct dispatch_params *data = param;
794 
795     /* Note: I_RpcSendReceive doesn't raise exceptions like the higher-level
796      * RPC functions do */
797     data->status = I_RpcSendReceive((RPC_MESSAGE *)data->msg);
798 
799     TRACE("completed with status 0x%x\n", data->status);
800 
801     SetEvent(data->handle);
802 
803     return 0;
804 }
805 
ClientRpcChannelBuffer_IsCorrectApartment(ClientRpcChannelBuffer * This,APARTMENT * apt)806 static inline HRESULT ClientRpcChannelBuffer_IsCorrectApartment(ClientRpcChannelBuffer *This, APARTMENT *apt)
807 {
808     OXID oxid;
809     if (!apt)
810         return S_FALSE;
811     if (apartment_getoxid(apt, &oxid) != S_OK)
812         return S_FALSE;
813     if (This->oxid != oxid)
814         return S_FALSE;
815     return S_OK;
816 }
817 
/* Client-side IRpcChannelBuffer::SendReceive: sends the request prepared by
 * ClientRpcChannelBuffer_GetBuffer and waits for the reply.
 *
 * The call is delivered in one of two ways. When the message state is
 * flagged bypass_rpcrt, it is posted (DM_EXECUTERPC) directly to the target
 * apartment window. Otherwise a work item runs I_RpcSendReceive
 * (rpc_sendreceive_thread). While waiting, CoWaitForMultipleHandles pumps
 * messages for this thread.
 *
 * On success the ORPCTHAT header at the front of the reply is parsed.
 * msg->Buffer/BufferLength are then advanced past it and the client channel
 * hooks are notified. The RPC status is stored in *pstatus (if non-NULL).
 * The returned HRESULT is, in order of precedence: a local failure, a failed
 * RPC status, or the fault HRESULT (hrFault) reported by the server. */
static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER iface, RPCOLEMESSAGE *olemsg, ULONG *pstatus)
{
    ClientRpcChannelBuffer *This = (ClientRpcChannelBuffer *)iface;
    HRESULT hr;
    RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
    RPC_STATUS status;
    DWORD index;
    struct message_state *message_state;
    ORPCTHAT orpcthat;
    ORPC_EXTENT_ARRAY orpc_ext_array;
    WIRE_ORPC_EXTENT *first_wire_orpc_extent = NULL;
    HRESULT hrFault = S_OK;
    /* reference is dropped with apartment_release on every return path */
    APARTMENT *apt = apartment_get_current_or_mta();

    TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod);

    /* a proxy channel may only be used from the apartment it belongs to */
    hr = ClientRpcChannelBuffer_IsCorrectApartment(This, apt);
    if (hr != S_OK)
    {
        ERR("called from wrong apartment, should have been 0x%s\n",
            wine_dbgstr_longlong(This->oxid));
        if (apt) apartment_release(apt);
        return RPC_E_WRONG_THREAD;
    }
    /* This situation should be impossible in multi-threaded apartments,
     * because the calling thread isn't re-enterable.
     * Note: doing a COM call during the processing of a sent message is
     * only disallowed if a client call is already being waited for
     * completion */
    if (!apt->multi_threaded &&
        COM_CurrentInfo()->pending_call_count_client &&
        InSendMessage())
    {
        ERR("can't make an outgoing COM call in response to a sent message\n");
        apartment_release(apt);
        return RPC_E_CANTCALLOUT_ININPUTSYNCCALL;
    }

    /* GetBuffer stashed the per-call state in msg->Handle */
    message_state = msg->Handle;
    /* restore the binding handle and the real start of data */
    msg->Handle = message_state->binding_handle;
    msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
    msg->BufferLength += message_state->prefix_data_len;

    /* Note: this is an optimization in the Microsoft OLE runtime that we need
     * to copy, as shown by the test_no_couninitialize_client test. without
     * short-circuiting the RPC runtime in the case below, the test will
     * deadlock on the loader lock due to the RPC runtime needing to create
     * a thread to process the RPC when this function is called indirectly
     * from DllMain */

    message_state->params.msg = olemsg;
    if (message_state->params.bypass_rpcrt)
    {
        TRACE("Calling apartment thread 0x%08x...\n", message_state->target_tid);

        /* clear RPC_FLAGS_VALID_BIT from ProcNum before handing the message
         * straight to the target apartment */
        msg->ProcNum &= ~RPC_FLAGS_VALID_BIT;

        if (!PostMessageW(message_state->target_hwnd, DM_EXECUTERPC, 0,
                          (LPARAM)&message_state->params))
        {
            ERR("PostMessage failed with error %u\n", GetLastError());

            /* Note: message_state->params.iface doesn't have a reference and
             * so doesn't need to be released */

            hr = HRESULT_FROM_WIN32(GetLastError());
        }
    }
    else
    {
        /* we use a separate thread here because we need to be able to
         * pump the message loop in the application thread: if we do not,
         * any windows created by this thread will hang and RPCs that try
         * and re-enter this STA from an incoming server thread will
         * deadlock. InstallShield is an example of that.
         */
        if (!QueueUserWorkItem(rpc_sendreceive_thread, &message_state->params, WT_EXECUTEDEFAULT))
        {
            ERR("QueueUserWorkItem failed with error %u\n", GetLastError());
            hr = E_UNEXPECTED;
        }
        else
            hr = S_OK;
    }

    if (hr == S_OK)
    {
        /* a zero-timeout poll first: only enter the message-pumping wait
         * (and count this as a pending client call) if the call has not
         * completed yet */
        if (WaitForSingleObject(message_state->params.handle, 0))
        {
            COM_CurrentInfo()->pending_call_count_client++;
            hr = CoWaitForMultipleHandles(0, INFINITE, 1, &message_state->params.handle, &index);
            COM_CurrentInfo()->pending_call_count_client--;
        }
    }
    ClientRpcChannelBuffer_ReleaseEventHandle(This, message_state->params.handle);

    /* for WM shortcut, faults are returned in params->hr */
    if (hr == S_OK)
        hrFault = message_state->params.hr;

    status = message_state->params.status;

    /* defaults used when no ORPCTHAT is parsed below */
    orpcthat.flags = ORPCF_NULL;
    orpcthat.extensions = NULL;

    TRACE("RPC call status: 0x%x\n", status);
    if (status != RPC_S_OK)
        hr = HRESULT_FROM_WIN32(status);

    TRACE("hrFault = 0x%08x\n", hrFault);

    /* FIXME: this condition should be
     * "hr == S_OK && (!hrFault || msg->BufferLength > FIELD_OFFSET(ORPCTHAT, extensions) + 4)"
     * but we don't currently reset the message length for PostMessage
     * dispatched calls */
    if (hr == S_OK && hrFault == S_OK)
    {
        HRESULT hr2;
        char *original_buffer = msg->Buffer;

        /* handle ORPCTHAT and client extensions */

        hr2 = unmarshal_ORPCTHAT(msg, &orpcthat, &orpc_ext_array, &first_wire_orpc_extent);
        if (FAILED(hr2))
            hr = hr2;

        /* remember how far the reply header moved msg->Buffer so that
         * FreeBuffer can recover the start of the allocation */
        message_state->prefix_data_len = (char *)msg->Buffer - original_buffer;
        msg->BufferLength -= message_state->prefix_data_len;
    }
    else
        message_state->prefix_data_len = 0;

    if (hr == S_OK)
    {
        ChannelHooks_ClientNotify(&message_state->channel_hook_info,
                                  msg->DataRepresentation,
                                  first_wire_orpc_extent,
                                  orpcthat.extensions && first_wire_orpc_extent ? orpcthat.extensions->size : 0,
                                  hrFault);
    }

    /* save away the message state again */
    msg->Handle = message_state;

    if (pstatus) *pstatus = status;

    /* a transport/local failure takes precedence over the server's fault */
    if (hr == S_OK)
        hr = hrFault;

    TRACE("-- 0x%08x\n", hr);

    apartment_release(apt);
    return hr;
}
973 
ServerRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface,RPCOLEMESSAGE * olemsg)974 static HRESULT WINAPI ServerRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface, RPCOLEMESSAGE* olemsg)
975 {
976     RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
977     RPC_STATUS status;
978     struct message_state *message_state;
979 
980     TRACE("(%p)\n", msg);
981 
982     message_state = msg->Handle;
983     /* restore the binding handle and the real start of data */
984     msg->Handle = message_state->binding_handle;
985     msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
986     msg->BufferLength += message_state->prefix_data_len;
987     message_state->prefix_data_len = 0;
988 
989     if (message_state->bypass_rpcrt)
990     {
991         HeapFree(GetProcessHeap(), 0, msg->Buffer);
992         status = RPC_S_OK;
993     }
994     else
995         status = I_RpcFreeBuffer(msg);
996 
997     msg->Handle = message_state;
998 
999     TRACE("-- %d\n", status);
1000 
1001     return HRESULT_FROM_WIN32(status);
1002 }
1003 
ClientRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface,RPCOLEMESSAGE * olemsg)1004 static HRESULT WINAPI ClientRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface, RPCOLEMESSAGE* olemsg)
1005 {
1006     RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
1007     RPC_STATUS status;
1008     struct message_state *message_state;
1009 
1010     TRACE("(%p)\n", msg);
1011 
1012     message_state = msg->Handle;
1013     /* restore the binding handle and the real start of data */
1014     msg->Handle = message_state->binding_handle;
1015     msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
1016     msg->BufferLength += message_state->prefix_data_len;
1017 
1018     if (message_state->params.bypass_rpcrt)
1019     {
1020         HeapFree(GetProcessHeap(), 0, msg->Buffer);
1021         status = RPC_S_OK;
1022     }
1023     else
1024         status = I_RpcFreeBuffer(msg);
1025 
1026     HeapFree(GetProcessHeap(), 0, msg->RpcInterfaceInformation);
1027     msg->RpcInterfaceInformation = NULL;
1028 
1029     if (message_state->params.stub)
1030         IRpcStubBuffer_Release(message_state->params.stub);
1031     if (message_state->params.chan)
1032         IRpcChannelBuffer_Release(message_state->params.chan);
1033     HeapFree(GetProcessHeap(), 0, message_state);
1034 
1035     TRACE("-- %d\n", status);
1036 
1037     return HRESULT_FROM_WIN32(status);
1038 }
1039 
ClientRpcChannelBuffer_GetDestCtx(LPRPCCHANNELBUFFER iface,DWORD * pdwDestContext,void ** ppvDestContext)1040 static HRESULT WINAPI ClientRpcChannelBuffer_GetDestCtx(LPRPCCHANNELBUFFER iface, DWORD* pdwDestContext, void** ppvDestContext)
1041 {
1042     ClientRpcChannelBuffer *This = (ClientRpcChannelBuffer *)iface;
1043 
1044     TRACE("(%p,%p)\n", pdwDestContext, ppvDestContext);
1045 
1046     *pdwDestContext = This->super.dest_context;
1047     *ppvDestContext = This->super.dest_context_data;
1048 
1049     return S_OK;
1050 }
1051 
ServerRpcChannelBuffer_GetDestCtx(LPRPCCHANNELBUFFER iface,DWORD * dest_context,void ** dest_context_data)1052 static HRESULT WINAPI ServerRpcChannelBuffer_GetDestCtx(LPRPCCHANNELBUFFER iface, DWORD* dest_context, void** dest_context_data)
1053 {
1054     RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
1055 
1056     TRACE("(%p,%p)\n", dest_context, dest_context_data);
1057 
1058     *dest_context = This->dest_context;
1059     *dest_context_data = This->dest_context_data;
1060     return S_OK;
1061 }
1062 
RpcChannelBuffer_IsConnected(LPRPCCHANNELBUFFER iface)1063 static HRESULT WINAPI RpcChannelBuffer_IsConnected(LPRPCCHANNELBUFFER iface)
1064 {
1065     TRACE("()\n");
1066     /* native does nothing too */
1067     return S_OK;
1068 }
1069 
/* Method table for client-side (proxy) channels created by
 * RPC_CreateClientChannel. QueryInterface, AddRef and IsConnected are the
 * shared implementations. The remaining slots use the Client* variants.
 * Entries must stay in IRpcChannelBuffer vtable order. */
static const IRpcChannelBufferVtbl ClientRpcChannelBufferVtbl =
{
    RpcChannelBuffer_QueryInterface,
    RpcChannelBuffer_AddRef,
    ClientRpcChannelBuffer_Release,
    ClientRpcChannelBuffer_GetBuffer,
    ClientRpcChannelBuffer_SendReceive,
    ClientRpcChannelBuffer_FreeBuffer,
    ClientRpcChannelBuffer_GetDestCtx,
    RpcChannelBuffer_IsConnected
};
1081 
/* Method table for server-side (stub) channels created by
 * RPC_CreateServerChannel. QueryInterface, AddRef and IsConnected are the
 * shared implementations. SendReceive is a stub returning E_NOTIMPL.
 * Entries must stay in IRpcChannelBuffer vtable order. */
static const IRpcChannelBufferVtbl ServerRpcChannelBufferVtbl =
{
    RpcChannelBuffer_QueryInterface,
    RpcChannelBuffer_AddRef,
    ServerRpcChannelBuffer_Release,
    ServerRpcChannelBuffer_GetBuffer,
    ServerRpcChannelBuffer_SendReceive,
    ServerRpcChannelBuffer_FreeBuffer,
    ServerRpcChannelBuffer_GetDestCtx,
    RpcChannelBuffer_IsConnected
};
1093 
1094 /* returns a channel buffer for proxies */
RPC_CreateClientChannel(const OXID * oxid,const IPID * ipid,const OXID_INFO * oxid_info,const IID * iid,DWORD dest_context,void * dest_context_data,IRpcChannelBuffer ** chan,APARTMENT * apt)1095 HRESULT RPC_CreateClientChannel(const OXID *oxid, const IPID *ipid,
1096                                 const OXID_INFO *oxid_info, const IID *iid,
1097                                 DWORD dest_context, void *dest_context_data,
1098                                 IRpcChannelBuffer **chan, APARTMENT *apt)
1099 {
1100     ClientRpcChannelBuffer *This;
1101     WCHAR                   endpoint[200];
1102     RPC_BINDING_HANDLE      bind;
1103     RPC_STATUS              status;
1104     LPWSTR                  string_binding;
1105 
1106     /* FIXME: get the endpoint from oxid_info->psa instead */
1107     get_rpc_endpoint(endpoint, oxid);
1108 
1109     TRACE("proxy pipe: connecting to endpoint: %s\n", debugstr_w(endpoint));
1110 
1111     status = RpcStringBindingComposeW(
1112         NULL,
1113         wszRpcTransport,
1114         NULL,
1115         endpoint,
1116         NULL,
1117         &string_binding);
1118 
1119     if (status == RPC_S_OK)
1120     {
1121         status = RpcBindingFromStringBindingW(string_binding, &bind);
1122 
1123         if (status == RPC_S_OK)
1124         {
1125             IPID ipid2 = *ipid; /* why can't RpcBindingSetObject take a const? */
1126             status = RpcBindingSetObject(bind, &ipid2);
1127             if (status != RPC_S_OK)
1128                 RpcBindingFree(&bind);
1129         }
1130 
1131         RpcStringFreeW(&string_binding);
1132     }
1133 
1134     if (status != RPC_S_OK)
1135     {
1136         ERR("Couldn't get binding for endpoint %s, status = %d\n", debugstr_w(endpoint), status);
1137         return HRESULT_FROM_WIN32(status);
1138     }
1139 
1140     This = HeapAlloc(GetProcessHeap(), 0, sizeof(*This));
1141     if (!This)
1142     {
1143         RpcBindingFree(&bind);
1144         return E_OUTOFMEMORY;
1145     }
1146 
1147     This->super.IRpcChannelBuffer_iface.lpVtbl = &ClientRpcChannelBufferVtbl;
1148     This->super.refs = 1;
1149     This->super.dest_context = dest_context;
1150     This->super.dest_context_data = dest_context_data;
1151     This->bind = bind;
1152     apartment_getoxid(apt, &This->oxid);
1153     This->server_pid = oxid_info->dwPid;
1154     This->event = NULL;
1155     This->iid = *iid;
1156 
1157     *chan = &This->super.IRpcChannelBuffer_iface;
1158 
1159     return S_OK;
1160 }
1161 
RPC_CreateServerChannel(DWORD dest_context,void * dest_context_data,IRpcChannelBuffer ** chan)1162 HRESULT RPC_CreateServerChannel(DWORD dest_context, void *dest_context_data, IRpcChannelBuffer **chan)
1163 {
1164     RpcChannelBuffer *This = HeapAlloc(GetProcessHeap(), 0, sizeof(*This));
1165     if (!This)
1166         return E_OUTOFMEMORY;
1167 
1168     This->IRpcChannelBuffer_iface.lpVtbl = &ServerRpcChannelBufferVtbl;
1169     This->refs = 1;
1170     This->dest_context = dest_context;
1171     This->dest_context_data = dest_context_data;
1172 
1173     *chan = &This->IRpcChannelBuffer_iface;
1174 
1175     return S_OK;
1176 }
1177 
1178 /* unmarshals ORPC_EXTENT_ARRAY according to NDR rules, but doesn't allocate
1179  * any memory */
unmarshal_ORPC_EXTENT_ARRAY(RPC_MESSAGE * msg,const char * end,ORPC_EXTENT_ARRAY * extensions,WIRE_ORPC_EXTENT ** first_wire_orpc_extent)1180 static HRESULT unmarshal_ORPC_EXTENT_ARRAY(RPC_MESSAGE *msg, const char *end,
1181                                            ORPC_EXTENT_ARRAY *extensions,
1182                                            WIRE_ORPC_EXTENT **first_wire_orpc_extent)
1183 {
1184     DWORD pointer_id;
1185     DWORD i;
1186 
1187     memcpy(extensions, msg->Buffer, FIELD_OFFSET(WIRE_ORPC_EXTENT_ARRAY, extent));
1188     msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT_ARRAY, extent);
1189 
1190     if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
1191         return RPC_E_INVALID_HEADER;
1192 
1193     pointer_id = *(DWORD *)msg->Buffer;
1194     msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
1195     extensions->extent = NULL;
1196 
1197     if (pointer_id)
1198     {
1199         WIRE_ORPC_EXTENT *wire_orpc_extent;
1200 
1201         /* conformance */
1202         if (*(DWORD *)msg->Buffer != ((extensions->size+1)&~1))
1203             return RPC_S_INVALID_BOUND;
1204 
1205         msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
1206 
1207         /* arbitrary limit for security (don't know what native does) */
1208         if (extensions->size > 256)
1209         {
1210             ERR("too many extensions: %d\n", extensions->size);
1211             return RPC_S_INVALID_BOUND;
1212         }
1213 
1214         *first_wire_orpc_extent = wire_orpc_extent = msg->Buffer;
1215         for (i = 0; i < ((extensions->size+1)&~1); i++)
1216         {
1217             if ((const char *)&wire_orpc_extent->data[0] > end)
1218                 return RPC_S_INVALID_BOUND;
1219             if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
1220                 return RPC_S_INVALID_BOUND;
1221             if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
1222                 return RPC_S_INVALID_BOUND;
1223             TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
1224             wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
1225         }
1226         msg->Buffer = wire_orpc_extent;
1227     }
1228 
1229     return S_OK;
1230 }
1231 
1232 /* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */
unmarshal_ORPCTHIS(RPC_MESSAGE * msg,ORPCTHIS * orpcthis,ORPC_EXTENT_ARRAY * orpc_ext_array,WIRE_ORPC_EXTENT ** first_wire_orpc_extent)1233 static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
1234     ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
1235 {
1236     const char *end = (char *)msg->Buffer + msg->BufferLength;
1237 
1238     *first_wire_orpc_extent = NULL;
1239 
1240     if (msg->BufferLength < FIELD_OFFSET(WIRE_ORPCTHIS, extensions) + sizeof(DWORD))
1241     {
1242         ERR("invalid buffer length\n");
1243         return RPC_E_INVALID_HEADER;
1244     }
1245 
1246     memcpy(orpcthis, msg->Buffer, FIELD_OFFSET(WIRE_ORPCTHIS, extensions));
1247     msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPCTHIS, extensions);
1248 
1249     if ((const char *)msg->Buffer + sizeof(DWORD) > end)
1250         return RPC_E_INVALID_HEADER;
1251 
1252     if (*(DWORD *)msg->Buffer)
1253         orpcthis->extensions = orpc_ext_array;
1254     else
1255         orpcthis->extensions = NULL;
1256 
1257     msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
1258 
1259     if (orpcthis->extensions)
1260     {
1261         HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
1262                                                  first_wire_orpc_extent);
1263         if (FAILED(hr))
1264             return hr;
1265     }
1266 
1267     if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) ||
1268         (orpcthis->version.MinorVersion > COM_MINOR_VERSION))
1269     {
1270         ERR("COM version {%d, %d} not supported\n",
1271             orpcthis->version.MajorVersion, orpcthis->version.MinorVersion);
1272         return RPC_E_VERSION_MISMATCH;
1273     }
1274 
1275     if (orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
1276     {
1277         ERR("invalid flags 0x%x\n", orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
1278         return RPC_E_INVALID_HEADER;
1279     }
1280 
1281     return S_OK;
1282 }
1283 
unmarshal_ORPCTHAT(RPC_MESSAGE * msg,ORPCTHAT * orpcthat,ORPC_EXTENT_ARRAY * orpc_ext_array,WIRE_ORPC_EXTENT ** first_wire_orpc_extent)1284 static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
1285                                   ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
1286 {
1287     const char *end = (char *)msg->Buffer + msg->BufferLength;
1288 
1289     *first_wire_orpc_extent = NULL;
1290 
1291     if (msg->BufferLength < FIELD_OFFSET(WIRE_ORPCTHAT, extensions) + sizeof(DWORD))
1292     {
1293         ERR("invalid buffer length\n");
1294         return RPC_E_INVALID_HEADER;
1295     }
1296 
1297     memcpy(orpcthat, msg->Buffer, FIELD_OFFSET(WIRE_ORPCTHAT, extensions));
1298     msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPCTHAT, extensions);
1299 
1300     if ((const char *)msg->Buffer + sizeof(DWORD) > end)
1301         return RPC_E_INVALID_HEADER;
1302 
1303     if (*(DWORD *)msg->Buffer)
1304         orpcthat->extensions = orpc_ext_array;
1305     else
1306         orpcthat->extensions = NULL;
1307 
1308     msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
1309 
1310     if (orpcthat->extensions)
1311     {
1312         HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
1313                                                  first_wire_orpc_extent);
1314         if (FAILED(hr))
1315             return hr;
1316     }
1317 
1318     if (orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
1319     {
1320         ERR("invalid flags 0x%x\n", orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
1321         return RPC_E_INVALID_HEADER;
1322     }
1323 
1324     return S_OK;
1325 }
1326 
/* Executes one incoming call on the current thread.
 *
 * params->msg holds the request with the ORPCTHIS header still in front of
 * the marshalled arguments. This function:
 *   - parses that header and notifies the server channel hooks of any
 *     extensions;
 *   - consults the current apartment's message filter, if any;
 *   - invokes the call through params->stub.
 *
 * The outcome is stored in params->hr. On every exit path msg->Buffer,
 * BufferLength and Handle are restored, and params->handle, if set, is
 * signalled so a thread waiting for the call can resume. */
void RPC_ExecuteCall(struct dispatch_params *params)
{
    struct message_state *message_state = NULL;
    RPC_MESSAGE *msg = (RPC_MESSAGE *)params->msg;
    char *original_buffer = msg->Buffer;
    ORPCTHIS orpcthis;
    ORPC_EXTENT_ARRAY orpc_ext_array;
    WIRE_ORPC_EXTENT *first_wire_orpc_extent;
    GUID old_causality_id;

    /* handle ORPCTHIS and server extensions */

    params->hr = unmarshal_ORPCTHIS(msg, &orpcthis, &orpc_ext_array, &first_wire_orpc_extent);
    if (params->hr != S_OK)
    {
        /* the parser may already have advanced msg->Buffer; rewind it */
        msg->Buffer = original_buffer;
        goto exit;
    }

    message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state));
    if (!message_state)
    {
        params->hr = E_OUTOFMEMORY;
        msg->Buffer = original_buffer;
        goto exit;
    }

    /* remember how far the header moved msg->Buffer so the server channel
     * (e.g. ServerRpcChannelBuffer_FreeBuffer) and the exit path below can
     * get back to the start of the data */
    message_state->prefix_data_len = (char *)msg->Buffer - original_buffer;
    message_state->binding_handle = msg->Handle;
    message_state->bypass_rpcrt = params->bypass_rpcrt;

    message_state->channel_hook_info.iid = params->iid;
    message_state->channel_hook_info.cbSize = sizeof(message_state->channel_hook_info);
    message_state->channel_hook_info.uCausality = orpcthis.cid;
    message_state->channel_hook_info.dwServerPid = GetCurrentProcessId();
    message_state->channel_hook_info.iMethod = msg->ProcNum;
    message_state->channel_hook_info.pObject = params->iface;

    if (orpcthis.extensions && first_wire_orpc_extent &&
        orpcthis.extensions->size)
        ChannelHooks_ServerNotify(&message_state->channel_hook_info, msg->DataRepresentation, first_wire_orpc_extent, orpcthis.extensions->size);

    /* while the stub runs, msg->Handle carries the message state instead of
     * the binding handle; the server channel methods swap them back */
    msg->Handle = message_state;
    msg->BufferLength -= message_state->prefix_data_len;

    /* call message filter */

    if (COM_CurrentApt()->filter)
    {
        DWORD handlecall;
        INTERFACEINFO interface_info;
        CALLTYPE calltype;

        interface_info.pUnk = params->iface;
        interface_info.iid = params->iid;
        interface_info.wMethod = msg->ProcNum;

        /* same causality id as the call this thread is already processing
         * => nested; otherwise top-level, with or without server calls
         * pending on this thread */
        if (IsEqualGUID(&orpcthis.cid, &COM_CurrentInfo()->causality_id))
            calltype = CALLTYPE_NESTED;
        else if (COM_CurrentInfo()->pending_call_count_server == 0)
            calltype = CALLTYPE_TOPLEVEL;
        else
            calltype = CALLTYPE_TOPLEVEL_CALLPENDING;

        handlecall = IMessageFilter_HandleInComingCall(COM_CurrentApt()->filter,
                                                       calltype,
                                                       UlongToHandle(GetCurrentProcessId()),
                                                       0 /* FIXME */,
                                                       &interface_info);
        TRACE("IMessageFilter_HandleInComingCall returned %d\n", handlecall);
        switch (handlecall)
        {
        case SERVERCALL_REJECTED:
            params->hr = RPC_E_CALL_REJECTED;
            goto exit_reset_state;
        case SERVERCALL_RETRYLATER:
#if 0 /* FIXME: handle retries on the client side before enabling this code */
            params->hr = RPC_E_RETRY;
            goto exit_reset_state;
#else
            FIXME("retry call later not implemented\n");
            break;
#endif
        case SERVERCALL_ISHANDLED:
        default:
            break;
        }
    }

    /* invoke the method */

    /* save the old causality ID - note: any calls executed while processing
     * messages received during the SendReceive will appear to originate from
     * this call - this should be checked with what Windows does */
    old_causality_id = COM_CurrentInfo()->causality_id;
    COM_CurrentInfo()->causality_id = orpcthis.cid;
    COM_CurrentInfo()->pending_call_count_server++;
    params->hr = IRpcStubBuffer_Invoke(params->stub, params->msg, params->chan);
    COM_CurrentInfo()->pending_call_count_server--;
    COM_CurrentInfo()->causality_id = old_causality_id;

    /* the invoke allocated a new buffer, so free the old one */
    /* NOTE(review): msg->Buffer was already advanced by prefix_data_len when
     * the ORPCTHIS header was parsed, so this comparison is true even if
     * Invoke never replaced the buffer (e.g. on a failing call). Confirm
     * that original_buffer cannot still be in use, and freed again later,
     * on that path. */
    if (message_state->bypass_rpcrt && original_buffer != msg->Buffer)
        HeapFree(GetProcessHeap(), 0, original_buffer);

exit_reset_state:
    /* re-read the state from msg->Handle rather than reusing the local,
     * presumably because channel calls made during Invoke can update the
     * buffer and its prefix length */
    message_state = msg->Handle;
    msg->Handle = message_state->binding_handle;
    msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
    msg->BufferLength += message_state->prefix_data_len;

exit:
    HeapFree(GetProcessHeap(), 0, message_state);
    /* wake the thread waiting for this call, if any */
    if (params->handle) SetEvent(params->handle);
}
1442 
dispatch_rpc(RPC_MESSAGE * msg)1443 static void __RPC_STUB dispatch_rpc(RPC_MESSAGE *msg)
1444 {
1445     struct dispatch_params *params;
1446     struct stub_manager *stub_manager;
1447     APARTMENT *apt;
1448     IPID ipid;
1449     HRESULT hr;
1450 
1451     RpcBindingInqObject(msg->Handle, &ipid);
1452 
1453     TRACE("ipid = %s, iMethod = %d\n", debugstr_guid(&ipid), msg->ProcNum);
1454 
1455     params = HeapAlloc(GetProcessHeap(), 0, sizeof(*params));
1456     if (!params)
1457     {
1458         RpcRaiseException(E_OUTOFMEMORY);
1459         return;
1460     }
1461 
1462     hr = ipid_get_dispatch_params(&ipid, &apt, &stub_manager, &params->stub, &params->chan,
1463                                   &params->iid, &params->iface);
1464     if (hr != S_OK)
1465     {
1466         ERR("no apartment found for ipid %s\n", debugstr_guid(&ipid));
1467         HeapFree(GetProcessHeap(), 0, params);
1468         RpcRaiseException(hr);
1469         return;
1470     }
1471 
1472     params->msg = (RPCOLEMESSAGE *)msg;
1473     params->status = RPC_S_OK;
1474     params->hr = S_OK;
1475     params->handle = NULL;
1476     params->bypass_rpcrt = FALSE;
1477 
1478     /* Note: this is the important difference between STAs and MTAs - we
1479      * always execute RPCs to STAs in the thread that originally created the
1480      * apartment (i.e. the one that pumps messages to the window) */
1481     if (!apt->multi_threaded)
1482     {
1483         params->handle = CreateEventW(NULL, FALSE, FALSE, NULL);
1484 
1485         TRACE("Calling apartment thread 0x%08x...\n", apt->tid);
1486 
1487         if (PostMessageW(apartment_getwindow(apt), DM_EXECUTERPC, 0, (LPARAM)params))
1488             WaitForSingleObject(params->handle, INFINITE);
1489         else
1490         {
1491             ERR("PostMessage failed with error %u\n", GetLastError());
1492             IRpcChannelBuffer_Release(params->chan);
1493             IRpcStubBuffer_Release(params->stub);
1494         }
1495         CloseHandle(params->handle);
1496     }
1497     else
1498     {
1499         BOOL joined = FALSE;
1500         struct oletls *info = COM_CurrentInfo();
1501 
1502         if (!info->apt)
1503         {
1504             enter_apartment(info, COINIT_MULTITHREADED);
1505             joined = TRUE;
1506         }
1507         RPC_ExecuteCall(params);
1508         if (joined)
1509         {
1510             leave_apartment(info);
1511         }
1512     }
1513 
1514     hr = params->hr;
1515     if (params->chan)
1516         IRpcChannelBuffer_Release(params->chan);
1517     if (params->stub)
1518         IRpcStubBuffer_Release(params->stub);
1519     HeapFree(GetProcessHeap(), 0, params);
1520 
1521     stub_manager_int_release(stub_manager);
1522     apartment_release(apt);
1523 
1524     /* if IRpcStubBuffer_Invoke fails, we should raise an exception to tell
1525      * the RPC runtime that the call failed */
1526     if (hr != S_OK) RpcRaiseException(hr);
1527 }
1528 
1529 /* stub registration */
RPC_RegisterInterface(REFIID riid)1530 HRESULT RPC_RegisterInterface(REFIID riid)
1531 {
1532     struct registered_if *rif;
1533     BOOL found = FALSE;
1534     HRESULT hr = S_OK;
1535 
1536     TRACE("(%s)\n", debugstr_guid(riid));
1537 
1538     EnterCriticalSection(&csRegIf);
1539     LIST_FOR_EACH_ENTRY(rif, &registered_interfaces, struct registered_if, entry)
1540     {
1541         if (IsEqualGUID(&rif->If.InterfaceId.SyntaxGUID, riid))
1542         {
1543             rif->refs++;
1544             found = TRUE;
1545             break;
1546         }
1547     }
1548     if (!found)
1549     {
1550         TRACE("Creating new interface\n");
1551 
1552         rif = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*rif));
1553         if (rif)
1554         {
1555             RPC_STATUS status;
1556 
1557             rif->refs = 1;
1558             rif->If.Length = sizeof(RPC_SERVER_INTERFACE);
1559             /* RPC interface ID = COM interface ID */
1560             rif->If.InterfaceId.SyntaxGUID = *riid;
1561             rif->If.DispatchTable = &rpc_dispatch;
1562             /* all other fields are 0, including the version asCOM objects
1563              * always have a version of 0.0 */
1564             status = RpcServerRegisterIfEx(
1565                 (RPC_IF_HANDLE)&rif->If,
1566                 NULL, NULL,
1567                 RPC_IF_OLE | RPC_IF_AUTOLISTEN,
1568                 RPC_C_LISTEN_MAX_CALLS_DEFAULT,
1569                 NULL);
1570             if (status == RPC_S_OK)
1571                 list_add_tail(&registered_interfaces, &rif->entry);
1572             else
1573             {
1574                 ERR("RpcServerRegisterIfEx failed with error %d\n", status);
1575                 HeapFree(GetProcessHeap(), 0, rif);
1576                 hr = HRESULT_FROM_WIN32(status);
1577             }
1578         }
1579         else
1580             hr = E_OUTOFMEMORY;
1581     }
1582     LeaveCriticalSection(&csRegIf);
1583     return hr;
1584 }
1585 
1586 /* stub unregistration */
RPC_UnregisterInterface(REFIID riid,BOOL wait)1587 void RPC_UnregisterInterface(REFIID riid, BOOL wait)
1588 {
1589     struct registered_if *rif;
1590     EnterCriticalSection(&csRegIf);
1591     LIST_FOR_EACH_ENTRY(rif, &registered_interfaces, struct registered_if, entry)
1592     {
1593         if (IsEqualGUID(&rif->If.InterfaceId.SyntaxGUID, riid))
1594         {
1595             if (!--rif->refs)
1596             {
1597                 RpcServerUnregisterIf((RPC_IF_HANDLE)&rif->If, NULL, wait);
1598                 list_remove(&rif->entry);
1599                 HeapFree(GetProcessHeap(), 0, rif);
1600             }
1601             break;
1602         }
1603     }
1604     LeaveCriticalSection(&csRegIf);
1605 }
1606 
1607 /* get the info for an OXID, including the IPID for the rem unknown interface
1608  * and the string binding */
RPC_ResolveOxid(OXID oxid,OXID_INFO * oxid_info)1609 HRESULT RPC_ResolveOxid(OXID oxid, OXID_INFO *oxid_info)
1610 {
1611     TRACE("%s\n", wine_dbgstr_longlong(oxid));
1612 
1613     oxid_info->dwTid = 0;
1614     oxid_info->dwPid = 0;
1615     oxid_info->dwAuthnHint = RPC_C_AUTHN_LEVEL_NONE;
1616     /* FIXME: this is a hack around not having an OXID resolver yet -
1617      * this function should contact the machine's OXID resolver and then it
1618      * should give us the IPID of the IRemUnknown interface */
1619     oxid_info->ipidRemUnknown.Data1 = 0xffffffff;
1620     oxid_info->ipidRemUnknown.Data2 = 0xffff;
1621     oxid_info->ipidRemUnknown.Data3 = 0xffff;
1622     memcpy(oxid_info->ipidRemUnknown.Data4, &oxid, sizeof(OXID));
1623     oxid_info->psa = NULL /* FIXME */;
1624 
1625     return S_OK;
1626 }
1627 
1628 /* make the apartment reachable by other threads and processes and create the
1629  * IRemUnknown object */
RPC_StartRemoting(struct apartment * apt)1630 void RPC_StartRemoting(struct apartment *apt)
1631 {
1632     if (!InterlockedExchange(&apt->remoting_started, TRUE))
1633     {
1634         WCHAR endpoint[200];
1635         RPC_STATUS status;
1636 
1637         get_rpc_endpoint(endpoint, &apt->oxid);
1638 
1639         status = RpcServerUseProtseqEpW(
1640             wszRpcTransport,
1641             RPC_C_PROTSEQ_MAX_REQS_DEFAULT,
1642             endpoint,
1643             NULL);
1644         if (status != RPC_S_OK)
1645             ERR("Couldn't register endpoint %s\n", debugstr_w(endpoint));
1646 
1647         /* FIXME: move remote unknown exporting into this function */
1648     }
1649     start_apartment_remote_unknown(apt);
1650 }
1651 
1652 
create_server(REFCLSID rclsid,HANDLE * process)1653 static HRESULT create_server(REFCLSID rclsid, HANDLE *process)
1654 {
1655     static const WCHAR  wszLocalServer32[] = { 'L','o','c','a','l','S','e','r','v','e','r','3','2',0 };
1656     static const WCHAR  embedding[] = { ' ', '-','E','m','b','e','d','d','i','n','g',0 };
1657     HKEY                key;
1658     HRESULT             hres;
1659     WCHAR               command[MAX_PATH+ARRAY_SIZE(embedding)];
1660     DWORD               size = (MAX_PATH+1) * sizeof(WCHAR);
1661     STARTUPINFOW        sinfo;
1662     PROCESS_INFORMATION pinfo;
1663     LONG ret;
1664 
1665     hres = COM_OpenKeyForCLSID(rclsid, wszLocalServer32, KEY_READ, &key);
1666     if (FAILED(hres)) {
1667         ERR("class %s not registered\n", debugstr_guid(rclsid));
1668         return hres;
1669     }
1670 
1671     ret = RegQueryValueExW(key, NULL, NULL, NULL, (LPBYTE)command, &size);
1672     RegCloseKey(key);
1673     if (ret) {
1674         WARN("No default value for LocalServer32 key\n");
1675         return REGDB_E_CLASSNOTREG; /* FIXME: check retval */
1676     }
1677 
1678     memset(&sinfo,0,sizeof(sinfo));
1679     sinfo.cb = sizeof(sinfo);
1680 
1681     /* EXE servers are started with the -Embedding switch. */
1682 
1683     lstrcatW(command, embedding);
1684 
1685     TRACE("activating local server %s for %s\n", debugstr_w(command), debugstr_guid(rclsid));
1686 
1687     /* FIXME: Win2003 supports a ServerExecutable value that is passed into
1688      * CreateProcess */
1689     if (!CreateProcessW(NULL, command, NULL, NULL, FALSE, DETACHED_PROCESS, NULL, NULL, &sinfo, &pinfo)) {
1690         WARN("failed to run local server %s\n", debugstr_w(command));
1691         return HRESULT_FROM_WIN32(GetLastError());
1692     }
1693     *process = pinfo.hProcess;
1694     CloseHandle(pinfo.hThread);
1695 
1696     return S_OK;
1697 }
1698 
1699 /*
1700  * start_local_service()  - start a service given its name and parameters
1701  */
start_local_service(LPCWSTR name,DWORD num,LPCWSTR * params)1702 static DWORD start_local_service(LPCWSTR name, DWORD num, LPCWSTR *params)
1703 {
1704     SC_HANDLE handle, hsvc;
1705     DWORD     r = ERROR_FUNCTION_FAILED;
1706 
1707     TRACE("Starting service %s %d params\n", debugstr_w(name), num);
1708 
1709     handle = OpenSCManagerW(NULL, NULL, SC_MANAGER_CONNECT);
1710     if (!handle)
1711         return r;
1712     hsvc = OpenServiceW(handle, name, SERVICE_START);
1713     if (hsvc)
1714     {
1715         if(StartServiceW(hsvc, num, params))
1716             r = ERROR_SUCCESS;
1717         else
1718             r = GetLastError();
1719         if (r == ERROR_SERVICE_ALREADY_RUNNING)
1720             r = ERROR_SUCCESS;
1721         CloseServiceHandle(hsvc);
1722     }
1723     else
1724         r = GetLastError();
1725     CloseServiceHandle(handle);
1726 
1727     TRACE("StartService returned error %u (%s)\n", r, (r == ERROR_SUCCESS) ? "ok":"failed");
1728 
1729     return r;
1730 }
1731 
1732 /*
1733  * create_local_service()  - start a COM server in a service
1734  *
1735  *   To start a Local Service, we read the AppID value under
1736  * the class's CLSID key, then open the HKCR\\AppId key specified
1737  * there and check for a LocalService value.
1738  *
1739  * Note:  Local Services are not supported under Windows 9x
1740  */
create_local_service(REFCLSID rclsid)1741 static HRESULT create_local_service(REFCLSID rclsid)
1742 {
1743     HRESULT hres;
1744     WCHAR buf[CHARS_IN_GUID];
1745     static const WCHAR szLocalService[] = { 'L','o','c','a','l','S','e','r','v','i','c','e',0 };
1746     static const WCHAR szServiceParams[] = {'S','e','r','v','i','c','e','P','a','r','a','m','s',0};
1747     HKEY hkey;
1748     LONG r;
1749     DWORD type, sz;
1750 
1751     TRACE("Attempting to start Local service for %s\n", debugstr_guid(rclsid));
1752 
1753     hres = COM_OpenKeyForAppIdFromCLSID(rclsid, KEY_READ, &hkey);
1754     if (FAILED(hres))
1755         return hres;
1756 
1757     /* read the LocalService and ServiceParameters values from the AppID key */
1758     sz = sizeof buf;
1759     r = RegQueryValueExW(hkey, szLocalService, NULL, &type, (LPBYTE)buf, &sz);
1760     if (r==ERROR_SUCCESS && type==REG_SZ)
1761     {
1762         DWORD num_args = 0;
1763         LPWSTR args[1] = { NULL };
1764 
1765         /*
1766          * FIXME: I'm not really sure how to deal with the service parameters.
1767          *        I suspect that the string returned from RegQueryValueExW
1768          *        should be split into a number of arguments by spaces.
1769          *        It would make more sense if ServiceParams contained a
1770          *        REG_MULTI_SZ here, but it's a REG_SZ for the services
1771          *        that I'm interested in for the moment.
1772          */
1773         r = RegQueryValueExW(hkey, szServiceParams, NULL, &type, NULL, &sz);
1774         if (r == ERROR_SUCCESS && type == REG_SZ && sz)
1775         {
1776             args[0] = HeapAlloc(GetProcessHeap(),HEAP_ZERO_MEMORY,sz);
1777             num_args++;
1778             RegQueryValueExW(hkey, szServiceParams, NULL, &type, (LPBYTE)args[0], &sz);
1779         }
1780         r = start_local_service(buf, num_args, (LPCWSTR *)args);
1781         if (r != ERROR_SUCCESS)
1782             hres = REGDB_E_CLASSNOTREG; /* FIXME: check retval */
1783         HeapFree(GetProcessHeap(),0,args[0]);
1784     }
1785     else
1786     {
1787         WARN("No LocalService value\n");
1788         hres = REGDB_E_CLASSNOTREG; /* FIXME: check retval */
1789     }
1790     RegCloseKey(hkey);
1791 
1792     return hres;
1793 }
1794 
1795 
/* Build the name of the named pipe a local server for rclsid listens on:
 * the prefix "\\.\pipe\" followed by the CLSID as formatted by
 * StringFromGUID2.  pipefn must have room for the 9-character prefix plus
 * CHARS_IN_GUID characters. */
static void get_localserver_pipe_name(WCHAR *pipefn, REFCLSID rclsid)
{
    static const WCHAR wszPipeRef[] = {'\\','\\','.','\\','p','i','p','e','\\',0};
    WCHAR *guid_part = pipefn + (ARRAY_SIZE(wszPipeRef) - 1);

    lstrcpyW(pipefn, wszPipeRef);
    StringFromGUID2(rclsid, guid_part, CHARS_IN_GUID);
}
1802 
1803 /* FIXME: should call to rpcss instead */
RPC_GetLocalClassObject(REFCLSID rclsid,REFIID iid,LPVOID * ppv)1804 HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv)
1805 {
1806     HRESULT        hres;
1807     HANDLE         hPipe;
1808     WCHAR          pipefn[100];
1809     DWORD          res, bufferlen;
1810     char           marshalbuffer[200];
1811     IStream       *pStm;
1812     LARGE_INTEGER  seekto;
1813     ULARGE_INTEGER newpos;
1814     int            tries = 0;
1815     IServiceProvider *local_server;
1816 
1817     static const int MAXTRIES = 30; /* 30 seconds */
1818 
1819     TRACE("rclsid=%s, iid=%s\n", debugstr_guid(rclsid), debugstr_guid(iid));
1820 
1821     get_localserver_pipe_name(pipefn, rclsid);
1822 
1823     while (tries++ < MAXTRIES) {
1824         TRACE("waiting for %s\n", debugstr_w(pipefn));
1825 
1826         WaitNamedPipeW( pipefn, NMPWAIT_WAIT_FOREVER );
1827         hPipe = CreateFileW(pipefn, GENERIC_READ | GENERIC_WRITE, 0, NULL, OPEN_EXISTING, 0, 0);
1828         if (hPipe == INVALID_HANDLE_VALUE) {
1829             DWORD index;
1830             DWORD start_ticks;
1831             HANDLE process = 0;
1832             if (tries == 1) {
1833                 if ( (hres = create_local_service(rclsid)) &&
1834                      (hres = create_server(rclsid, &process)) )
1835                     return hres;
1836             } else {
1837                 WARN("Connecting to %s, no response yet, retrying: le is %u\n", debugstr_w(pipefn), GetLastError());
1838             }
1839             /* wait for one second, even if messages arrive */
1840             start_ticks = GetTickCount();
1841             do {
1842                 if (SUCCEEDED(CoWaitForMultipleHandles(0, 1000, (process != 0),
1843                                                        &process, &index)) && process && !index)
1844                 {
1845                     WARN( "server for %s failed to start\n", debugstr_guid(rclsid) );
1846                     CloseHandle( hPipe );
1847                     CloseHandle( process );
1848                     return E_NOINTERFACE;
1849                 }
1850             } while (GetTickCount() - start_ticks < 1000);
1851             if (process) CloseHandle( process );
1852             continue;
1853         }
1854         bufferlen = 0;
1855         if (!ReadFile(hPipe,marshalbuffer,sizeof(marshalbuffer),&bufferlen,NULL)) {
1856             FIXME("Failed to read marshal id from classfactory of %s.\n",debugstr_guid(rclsid));
1857             CloseHandle(hPipe);
1858             Sleep(1000);
1859             continue;
1860         }
1861         TRACE("read marshal id from pipe\n");
1862         CloseHandle(hPipe);
1863         break;
1864     }
1865 
1866     if (tries >= MAXTRIES)
1867         return E_NOINTERFACE;
1868 
1869     hres = CreateStreamOnHGlobal(0,TRUE,&pStm);
1870     if (hres != S_OK) return hres;
1871     hres = IStream_Write(pStm,marshalbuffer,bufferlen,&res);
1872     if (hres != S_OK) goto out;
1873     seekto.u.LowPart = 0;seekto.u.HighPart = 0;
1874     hres = IStream_Seek(pStm,seekto,STREAM_SEEK_SET,&newpos);
1875 
1876     TRACE("unmarshalling local server\n");
1877     hres = CoUnmarshalInterface(pStm, &IID_IServiceProvider, (void**)&local_server);
1878     if(SUCCEEDED(hres))
1879         hres = IServiceProvider_QueryService(local_server, rclsid, iid, ppv);
1880     IServiceProvider_Release(local_server);
1881 out:
1882     IStream_Release(pStm);
1883     return hres;
1884 }
1885 
1886 
1887 struct local_server_params
1888 {
1889     CLSID clsid;
1890     IStream *stream;
1891     HANDLE pipe;
1892     HANDLE stop_event;
1893     HANDLE thread;
1894     BOOL multi_use;
1895 };
1896 
1897 /* FIXME: should call to rpcss instead */
local_server_thread(LPVOID param)1898 static DWORD WINAPI local_server_thread(LPVOID param)
1899 {
1900     struct local_server_params * lsp = param;
1901     WCHAR 		pipefn[100];
1902     HRESULT		hres;
1903     IStream		*pStm = lsp->stream;
1904     STATSTG		ststg;
1905     unsigned char	*buffer;
1906     int 		buflen;
1907     LARGE_INTEGER	seekto;
1908     ULARGE_INTEGER	newpos;
1909     ULONG		res;
1910     BOOL multi_use = lsp->multi_use;
1911     OVERLAPPED ovl;
1912     HANDLE pipe_event, hPipe = lsp->pipe, new_pipe;
1913     DWORD  bytes;
1914 
1915     TRACE("Starting threader for %s.\n",debugstr_guid(&lsp->clsid));
1916 
1917     memset(&ovl, 0, sizeof(ovl));
1918     get_localserver_pipe_name(pipefn, &lsp->clsid);
1919     ovl.hEvent = pipe_event = CreateEventW(NULL, FALSE, FALSE, NULL);
1920 
1921     while (1) {
1922         if (!ConnectNamedPipe(hPipe, &ovl))
1923         {
1924             DWORD error = GetLastError();
1925             if (error == ERROR_IO_PENDING)
1926             {
1927                 HANDLE handles[2] = { pipe_event, lsp->stop_event };
1928                 DWORD ret;
1929                 ret = WaitForMultipleObjects(2, handles, FALSE, INFINITE);
1930                 if (ret != WAIT_OBJECT_0)
1931                     break;
1932             }
1933             /* client already connected isn't an error */
1934             else if (error != ERROR_PIPE_CONNECTED)
1935             {
1936                 ERR("ConnectNamedPipe failed with error %d\n", GetLastError());
1937                 break;
1938             }
1939         }
1940 
1941         TRACE("marshalling LocalServer to client\n");
1942 
1943         hres = IStream_Stat(pStm,&ststg,STATFLAG_NONAME);
1944         if (hres != S_OK)
1945             break;
1946 
1947         seekto.u.LowPart = 0;
1948         seekto.u.HighPart = 0;
1949         hres = IStream_Seek(pStm,seekto,STREAM_SEEK_SET,&newpos);
1950         if (hres != S_OK) {
1951             FIXME("IStream_Seek failed, %x\n",hres);
1952             break;
1953         }
1954 
1955         buflen = ststg.cbSize.u.LowPart;
1956         buffer = HeapAlloc(GetProcessHeap(),0,buflen);
1957 
1958         hres = IStream_Read(pStm,buffer,buflen,&res);
1959         if (hres != S_OK) {
1960             FIXME("Stream Read failed, %x\n",hres);
1961             HeapFree(GetProcessHeap(),0,buffer);
1962             break;
1963         }
1964 
1965         WriteFile(hPipe,buffer,buflen,&res,&ovl);
1966         GetOverlappedResult(hPipe, &ovl, &bytes, TRUE);
1967         HeapFree(GetProcessHeap(),0,buffer);
1968 
1969         FlushFileBuffers(hPipe);
1970         DisconnectNamedPipe(hPipe);
1971         TRACE("done marshalling LocalServer\n");
1972 
1973         if (!multi_use)
1974         {
1975             TRACE("single use object, shutting down pipe %s\n", debugstr_w(pipefn));
1976             break;
1977         }
1978         new_pipe = CreateNamedPipeW( pipefn, PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED,
1979                                      PIPE_TYPE_BYTE|PIPE_WAIT, PIPE_UNLIMITED_INSTANCES,
1980                                      4096, 4096, 500 /* 0.5 second timeout */, NULL );
1981         if (new_pipe == INVALID_HANDLE_VALUE)
1982         {
1983             FIXME("pipe creation failed for %s, le is %u\n", debugstr_w(pipefn), GetLastError());
1984             break;
1985         }
1986         CloseHandle(hPipe);
1987         hPipe = new_pipe;
1988     }
1989 
1990     CloseHandle(pipe_event);
1991     CloseHandle(hPipe);
1992     return 0;
1993 }
1994 
1995 /* starts listening for a local server */
RPC_StartLocalServer(REFCLSID clsid,IStream * stream,BOOL multi_use,void ** registration)1996 HRESULT RPC_StartLocalServer(REFCLSID clsid, IStream *stream, BOOL multi_use, void **registration)
1997 {
1998     DWORD tid, err;
1999     struct local_server_params *lsp;
2000     WCHAR pipefn[100];
2001 
2002     lsp = HeapAlloc(GetProcessHeap(), 0, sizeof(*lsp));
2003     if (!lsp)
2004         return E_OUTOFMEMORY;
2005 
2006     lsp->clsid = *clsid;
2007     lsp->stream = stream;
2008     IStream_AddRef(stream);
2009     lsp->stop_event = CreateEventW(NULL, FALSE, FALSE, NULL);
2010     if (!lsp->stop_event)
2011     {
2012         HeapFree(GetProcessHeap(), 0, lsp);
2013         return HRESULT_FROM_WIN32(GetLastError());
2014     }
2015     lsp->multi_use = multi_use;
2016 
2017     get_localserver_pipe_name(pipefn, &lsp->clsid);
2018     lsp->pipe = CreateNamedPipeW(pipefn, PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED,
2019                                  PIPE_TYPE_BYTE|PIPE_WAIT, PIPE_UNLIMITED_INSTANCES,
2020                                  4096, 4096, 500 /* 0.5 second timeout */, NULL);
2021     if (lsp->pipe == INVALID_HANDLE_VALUE)
2022     {
2023         err = GetLastError();
2024         FIXME("pipe creation failed for %s, le is %u\n", debugstr_w(pipefn), GetLastError());
2025         CloseHandle(lsp->stop_event);
2026         HeapFree(GetProcessHeap(), 0, lsp);
2027         return HRESULT_FROM_WIN32(err);
2028     }
2029 
2030     lsp->thread = CreateThread(NULL, 0, local_server_thread, lsp, 0, &tid);
2031     if (!lsp->thread)
2032     {
2033         CloseHandle(lsp->pipe);
2034         CloseHandle(lsp->stop_event);
2035         HeapFree(GetProcessHeap(), 0, lsp);
2036         return HRESULT_FROM_WIN32(GetLastError());
2037     }
2038 
2039     *registration = lsp;
2040     return S_OK;
2041 }
2042 
2043 /* stops listening for a local server */
RPC_StopLocalServer(void * registration)2044 void RPC_StopLocalServer(void *registration)
2045 {
2046     struct local_server_params *lsp = registration;
2047 
2048     /* signal local_server_thread to stop */
2049     SetEvent(lsp->stop_event);
2050     /* wait for it to exit */
2051     WaitForSingleObject(lsp->thread, INFINITE);
2052 
2053     IStream_Release(lsp->stream);
2054     CloseHandle(lsp->stop_event);
2055     CloseHandle(lsp->thread);
2056     HeapFree(GetProcessHeap(), 0, lsp);
2057 }
2058