xref: /reactos/dll/win32/rpcrt4/rpc_assoc.c (revision c2c66aff)
1 /*
2  * Associations
3  *
4  * Copyright 2007 Robert Shearman (for CodeWeavers)
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  *
20  */
21 
22 #include "precomp.h"
23 
24 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
25 
/* Global lock for the association lists below.  It guards list membership of
 * client_assoc_list and server_assoc_list and the refs count of every
 * RpcAssoc on them (RPCRT4_GetAssociation, RpcServerAssoc_GetAssociation and
 * RpcAssoc_Release all touch refs only while holding it).
 * The tentative definition comes first so the static debug record can point
 * back at the critical section it describes. */
26 static CRITICAL_SECTION assoc_list_cs;
27 static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
28 {
29     0, 0, &assoc_list_cs,
30     { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
31       0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
32 };
33 static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
34 
/* Client associations, matched on protseq, address, endpoint and options. */
35 static struct list client_assoc_list = LIST_INIT(client_assoc_list);
/* Server associations, matched primarily on association group id. */
36 static struct list server_assoc_list = LIST_INIT(server_assoc_list);
37 
/* Last server association group id handed out; advanced with
 * InterlockedIncrement in RpcServerAssoc_GetAssociation. */
38 static LONG last_assoc_group_id;
39 
/* Server-side context handle.  The NDR_SCONTEXT values produced and consumed
 * by this file are pointers to this structure cast to NDR_SCONTEXT.
 * NOTE(review): 'entry' (two pointers) followed by 'user_context' presumably
 * mirrors the public NDR_SCONTEXT layout (pad[2], userContext), which would
 * explain why user_context is never written in this file -- do not reorder the
 * leading fields; confirm against rpcndr.h. */
40 typedef struct _RpcContextHandle
41 {
42     struct list entry;           /* link in RpcAssoc::context_handle_list, protected by assoc->cs */
43     void *user_context;          /* passed to rundown_routine by RpcContextHandle_Destroy */
44     NDR_RUNDOWN rundown_routine; /* set together with the uuid in RpcServerAssoc_UpdateContextHandle */
45     void *ctx_guard;             /* must equal the caller's guard for lookups/updates to succeed */
46     UUID uuid;                   /* nil (zero-filled) until RpcServerAssoc_UpdateContextHandle */
47     RTL_RWLOCK rw_lock;          /* taken exclusively by Allocate/Find, dropped by ReleaseContextHandle */
48     unsigned int refs;           /* protected by assoc->cs; handle is destroyed when it reaches 0 */
49 } RpcContextHandle;
50 
/* Runs the rundown routine (if any) and frees the handle; defined below. */
51 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
52 
53 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
54                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
55                                  RpcAssoc **assoc_out)
56 {
57     RpcAssoc *assoc;
58     assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
59     if (!assoc)
60         return RPC_S_OUT_OF_RESOURCES;
61     assoc->refs = 1;
62     list_init(&assoc->free_connection_pool);
63     list_init(&assoc->context_handle_list);
64     InitializeCriticalSection(&assoc->cs);
65     assoc->cs.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RpcAssoc.cs");
66     assoc->Protseq = RPCRT4_strdupA(Protseq);
67     assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
68     assoc->Endpoint = RPCRT4_strdupA(Endpoint);
69     assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
70     assoc->assoc_group_id = 0;
71     UuidCreate(&assoc->http_uuid);
72     list_init(&assoc->entry);
73     *assoc_out = assoc;
74     return RPC_S_OK;
75 }
76 
77 static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
78 {
79     if ((opts1 == NULL) && (opts2 == NULL))
80         return TRUE;
81     if ((opts1 == NULL) || (opts2 == NULL))
82         return FALSE;
83     return !strcmpW(opts1, opts2);
84 }
85 
86 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
87                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
88                                  RpcAssoc **assoc_out)
89 {
90     RpcAssoc *assoc;
91     RPC_STATUS status;
92 
93     EnterCriticalSection(&assoc_list_cs);
94     LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
95     {
96         if (!strcmp(Protseq, assoc->Protseq) &&
97             !strcmp(NetworkAddr, assoc->NetworkAddr) &&
98             !strcmp(Endpoint, assoc->Endpoint) &&
99             compare_networkoptions(NetworkOptions, assoc->NetworkOptions))
100         {
101             assoc->refs++;
102             *assoc_out = assoc;
103             LeaveCriticalSection(&assoc_list_cs);
104             TRACE("using existing assoc %p\n", assoc);
105             return RPC_S_OK;
106         }
107     }
108 
109     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
110     if (status != RPC_S_OK)
111     {
112         LeaveCriticalSection(&assoc_list_cs);
113         return status;
114     }
115     list_add_head(&client_assoc_list, &assoc->entry);
116     *assoc_out = assoc;
117 
118     LeaveCriticalSection(&assoc_list_cs);
119 
120     TRACE("new assoc %p\n", assoc);
121 
122     return RPC_S_OK;
123 }
124 
125 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
126                                          LPCSTR Endpoint, LPCWSTR NetworkOptions,
127                                          ULONG assoc_gid,
128                                          RpcAssoc **assoc_out)
129 {
130     RpcAssoc *assoc;
131     RPC_STATUS status;
132 
133     EnterCriticalSection(&assoc_list_cs);
134     if (assoc_gid)
135     {
136         LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
137         {
138             /* FIXME: NetworkAddr shouldn't be NULL */
139             if (assoc->assoc_group_id == assoc_gid &&
140                 !strcmp(Protseq, assoc->Protseq) &&
141                 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
142                 !strcmp(Endpoint, assoc->Endpoint) &&
143                 ((!assoc->NetworkOptions == !NetworkOptions) &&
144                  (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
145             {
146                 assoc->refs++;
147                 *assoc_out = assoc;
148                 LeaveCriticalSection(&assoc_list_cs);
149                 TRACE("using existing assoc %p\n", assoc);
150                 return RPC_S_OK;
151             }
152         }
153         *assoc_out = NULL;
154         LeaveCriticalSection(&assoc_list_cs);
155         return RPC_S_NO_CONTEXT_AVAILABLE;
156     }
157 
158     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
159     if (status != RPC_S_OK)
160     {
161         LeaveCriticalSection(&assoc_list_cs);
162         return status;
163     }
164     assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
165     list_add_head(&server_assoc_list, &assoc->entry);
166     *assoc_out = assoc;
167 
168     LeaveCriticalSection(&assoc_list_cs);
169 
170     TRACE("new assoc %p\n", assoc);
171 
172     return RPC_S_OK;
173 }
174 
175 ULONG RpcAssoc_Release(RpcAssoc *assoc)
176 {
177     ULONG refs;
178 
179     EnterCriticalSection(&assoc_list_cs);
180     refs = --assoc->refs;
181     if (!refs)
182         list_remove(&assoc->entry);
183     LeaveCriticalSection(&assoc_list_cs);
184 
185     if (!refs)
186     {
187         RpcConnection *Connection, *cursor2;
188         RpcContextHandle *context_handle, *context_handle_cursor;
189 
190         TRACE("destroying assoc %p\n", assoc);
191 
192         LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
193         {
194             list_remove(&Connection->conn_pool_entry);
195             RPCRT4_ReleaseConnection(Connection);
196         }
197 
198         LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
199             RpcContextHandle_Destroy(context_handle);
200 
201         HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
202         HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
203         HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
204         HeapFree(GetProcessHeap(), 0, assoc->Protseq);
205 
206         assoc->cs.DebugInfo->Spare[0] = 0;
207         DeleteCriticalSection(&assoc->cs);
208 
209         HeapFree(GetProcessHeap(), 0, assoc);
210     }
211 
212     return refs;
213 }
214 
215 #define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))
216 
217 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
218                                           const RPC_SYNTAX_IDENTIFIER *InterfaceId,
219                                           const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
220 {
221     RpcPktHdr *hdr;
222     RpcPktHdr *response_hdr;
223     RPC_MESSAGE msg;
224     RPC_STATUS status;
225     unsigned char *auth_data = NULL;
226     ULONG auth_length;
227 
228     TRACE("sending bind request to server\n");
229 
230     hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
231                                  RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
232                                  assoc->assoc_group_id,
233                                  InterfaceId, TransferSyntax);
234 
235     status = RPCRT4_Send(conn, hdr, NULL, 0);
236     RPCRT4_FreeHeader(hdr);
237     if (status != RPC_S_OK)
238         return status;
239 
240     status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
241     if (status != RPC_S_OK)
242     {
243         ERR("receive failed with error %d\n", status);
244         return status;
245     }
246 
247     switch (response_hdr->common.ptype)
248     {
249     case PKT_BIND_ACK:
250     {
251         RpcAddressString *server_address = msg.Buffer;
252         if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
253             (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
254         {
255             unsigned short remaining = msg.BufferLength -
256             ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
257             RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address +
258                 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
259             if ((results->num_results == 1) &&
260                 (remaining >= FIELD_OFFSET(RpcResultList, results[results->num_results])))
261             {
262                 switch (results->results[0].result)
263                 {
264                 case RESULT_ACCEPT:
265                     /* respond to authorization request */
266                     if (auth_length > sizeof(RpcAuthVerifier))
267                         status = RPCRT4_ClientConnectionAuth(conn,
268                                                              auth_data + sizeof(RpcAuthVerifier),
269                                                              auth_length);
270                     if (status == RPC_S_OK)
271                     {
272                         conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
273                         conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
274                         conn->ActiveInterface = *InterfaceId;
275                     }
276                     break;
277                 case RESULT_PROVIDER_REJECTION:
278                     switch (results->results[0].reason)
279                     {
280                     case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
281                         ERR("syntax %s, %d.%d not supported\n",
282                             debugstr_guid(&InterfaceId->SyntaxGUID),
283                             InterfaceId->SyntaxVersion.MajorVersion,
284                             InterfaceId->SyntaxVersion.MinorVersion);
285                         status = RPC_S_UNKNOWN_IF;
286                         break;
287                     case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
288                         ERR("transfer syntax not supported\n");
289                         status = RPC_S_SERVER_UNAVAILABLE;
290                         break;
291                     case REASON_NONE:
292                     default:
293                         status = RPC_S_CALL_FAILED_DNE;
294                     }
295                     break;
296                 case RESULT_USER_REJECTION:
297                 default:
298                     ERR("rejection result %d\n", results->results[0].result);
299                     status = RPC_S_CALL_FAILED_DNE;
300                 }
301             }
302             else
303             {
304                 ERR("incorrect results size\n");
305                 status = RPC_S_CALL_FAILED_DNE;
306             }
307         }
308         else
309         {
310             ERR("bind ack packet too small (%d)\n", msg.BufferLength);
311             status = RPC_S_PROTOCOL_ERROR;
312         }
313         break;
314     }
315     case PKT_BIND_NACK:
316         switch (response_hdr->bind_nack.reject_reason)
317         {
318         case REJECT_LOCAL_LIMIT_EXCEEDED:
319         case REJECT_TEMPORARY_CONGESTION:
320             ERR("server too busy\n");
321             status = RPC_S_SERVER_TOO_BUSY;
322             break;
323         case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
324             ERR("protocol version not supported\n");
325             status = RPC_S_PROTOCOL_ERROR;
326             break;
327         case REJECT_UNKNOWN_AUTHN_SERVICE:
328             ERR("unknown authentication service\n");
329             status = RPC_S_UNKNOWN_AUTHN_SERVICE;
330             break;
331         case REJECT_INVALID_CHECKSUM:
332             ERR("invalid checksum\n");
333             status = RPC_S_ACCESS_DENIED;
334             break;
335         default:
336             ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
337             status = RPC_S_CALL_FAILED_DNE;
338         }
339         break;
340     default:
341         ERR("wrong packet type received %d\n", response_hdr->common.ptype);
342         status = RPC_S_PROTOCOL_ERROR;
343         break;
344     }
345 
346     I_RpcFree(msg.Buffer);
347     RPCRT4_FreeHeader(response_hdr);
348     HeapFree(GetProcessHeap(), 0, auth_data);
349     return status;
350 }
351 
352 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
353                                                  const RPC_SYNTAX_IDENTIFIER *InterfaceId,
354                                                  const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
355                                                  const RpcQualityOfService *QOS)
356 {
357     RpcConnection *Connection;
358     EnterCriticalSection(&assoc->cs);
359     /* try to find a compatible connection from the connection pool */
360     LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
361     {
362         if (!memcmp(&Connection->ActiveInterface, InterfaceId,
363                     sizeof(RPC_SYNTAX_IDENTIFIER)) &&
364             RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
365             RpcQualityOfService_IsEqual(Connection->QOS, QOS))
366         {
367             list_remove(&Connection->conn_pool_entry);
368             LeaveCriticalSection(&assoc->cs);
369             TRACE("got connection from pool %p\n", Connection);
370             return Connection;
371         }
372     }
373 
374     LeaveCriticalSection(&assoc->cs);
375     return NULL;
376 }
377 
378 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
379                                         const RPC_SYNTAX_IDENTIFIER *InterfaceId,
380                                         const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
381                                         RpcQualityOfService *QOS, LPCWSTR CookieAuth, RpcConnection **Connection)
382 {
383     RpcConnection *NewConnection;
384     RPC_STATUS status;
385 
386     *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
387     if (*Connection)
388         return RPC_S_OK;
389 
390     /* create a new connection */
391     status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
392         assoc->Protseq, assoc->NetworkAddr,
393         assoc->Endpoint, assoc->NetworkOptions,
394         AuthInfo, QOS, CookieAuth);
395     if (status != RPC_S_OK)
396         return status;
397 
398     NewConnection->assoc = assoc;
399     status = RPCRT4_OpenClientConnection(NewConnection);
400     if (status != RPC_S_OK)
401     {
402         RPCRT4_ReleaseConnection(NewConnection);
403         return status;
404     }
405 
406     status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
407     if (status != RPC_S_OK)
408     {
409         RPCRT4_ReleaseConnection(NewConnection);
410         return status;
411     }
412 
413     *Connection = NewConnection;
414 
415     return RPC_S_OK;
416 }
417 
418 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
419 {
420     assert(!Connection->server);
421     Connection->async_state = NULL;
422     EnterCriticalSection(&assoc->cs);
423     if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
424     list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
425     LeaveCriticalSection(&assoc->cs);
426 }
427 
428 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
429                                                 NDR_SCONTEXT *SContext)
430 {
431     RpcContextHandle *context_handle;
432 
433     context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
434     if (!context_handle)
435         return RPC_S_OUT_OF_MEMORY;
436 
437     context_handle->ctx_guard = CtxGuard;
438     RtlInitializeResource(&context_handle->rw_lock);
439     context_handle->refs = 1;
440 
441     /* lock here to mirror unmarshall, so we don't need to special-case the
442      * freeing of a non-marshalled context handle */
443     RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
444 
445     EnterCriticalSection(&assoc->cs);
446     list_add_tail(&assoc->context_handle_list, &context_handle->entry);
447     LeaveCriticalSection(&assoc->cs);
448 
449     *SContext = (NDR_SCONTEXT)context_handle;
450     return RPC_S_OK;
451 }
452 
453 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
454 {
455     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
456     return context_handle->ctx_guard == CtxGuard;
457 }
458 
459 RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
460                                             void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
461 {
462     RpcContextHandle *context_handle;
463 
464     EnterCriticalSection(&assoc->cs);
465     LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
466     {
467         if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
468             !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
469         {
470             *SContext = (NDR_SCONTEXT)context_handle;
471             if (context_handle->refs++)
472             {
473                 LeaveCriticalSection(&assoc->cs);
474                 TRACE("found %p\n", context_handle);
475                 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
476                 return RPC_S_OK;
477             }
478         }
479     }
480     LeaveCriticalSection(&assoc->cs);
481 
482     ERR("no context handle found for uuid %s, guard %p\n",
483         debugstr_guid(uuid), CtxGuard);
484     return ERROR_INVALID_HANDLE;
485 }
486 
487 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
488                                               NDR_SCONTEXT SContext,
489                                               void *CtxGuard,
490                                               NDR_RUNDOWN rundown_routine)
491 {
492     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
493     RPC_STATUS status;
494 
495     if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
496         return ERROR_INVALID_HANDLE;
497 
498     EnterCriticalSection(&assoc->cs);
499     if (UuidIsNil(&context_handle->uuid, &status))
500     {
501         /* add a ref for the data being valid */
502         context_handle->refs++;
503         UuidCreate(&context_handle->uuid);
504         context_handle->rundown_routine = rundown_routine;
505         TRACE("allocated uuid %s for context handle %p\n",
506               debugstr_guid(&context_handle->uuid), context_handle);
507     }
508     LeaveCriticalSection(&assoc->cs);
509 
510     return RPC_S_OK;
511 }
512 
513 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
514 {
515     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
516     *uuid = context_handle->uuid;
517 }
518 
519 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
520 {
521     TRACE("freeing %p\n", context_handle);
522 
523     if (context_handle->user_context && context_handle->rundown_routine)
524     {
525         TRACE("calling rundown routine %p with user context %p\n",
526               context_handle->rundown_routine, context_handle->user_context);
527         context_handle->rundown_routine(context_handle->user_context);
528     }
529 
530     RtlDeleteResource(&context_handle->rw_lock);
531 
532     HeapFree(GetProcessHeap(), 0, context_handle);
533 }
534 
535 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
536 {
537     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
538     unsigned int refs;
539 
540     if (release_lock)
541         RtlReleaseResource(&context_handle->rw_lock);
542 
543     EnterCriticalSection(&assoc->cs);
544     refs = --context_handle->refs;
545     if (!refs)
546         list_remove(&context_handle->entry);
547     LeaveCriticalSection(&assoc->cs);
548 
549     if (!refs)
550         RpcContextHandle_Destroy(context_handle);
551 
552     return refs;
553 }
554