1 /**************************************************************************
2  *
3  * Copyright 2009-2013 VMware, Inc.
4  * All Rights Reserved.
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a
7  * copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sub license, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice (including the
15  * next paragraph) shall be included in all copies or substantial portions
16  * of the Software.
17  *
18  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19  * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21  * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22  * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25  *
26  **************************************************************************/
27 
28 #include <windows.h>
29 #include <tlhelp32.h>
30 
31 #include "pipe/p_compiler.h"
32 #include "util/u_debug.h"
33 #include "stw_tls.h"
34 
/**
 * TLS slot that holds each thread's struct stw_tls_data pointer.
 *
 * Stays TLS_OUT_OF_INDEXES until stw_tls_init() allocates it (or if that
 * allocation fails), and is reset to TLS_OUT_OF_INDEXES by stw_tls_cleanup();
 * every public entry point checks for that value first.
 */
static DWORD tlsIndex = TLS_OUT_OF_INDEXES;


/**
 * Static mutex to protect the access to g_pendingTlsData global and
 * stw_tls_data::next member.
 *
 * NOTE(review): this is initialized statically instead of through
 * InitializeCriticalSection(), which relies on the positional layout of
 * CRITICAL_SECTION (presumably DebugInfo, LockCount, RecursionCount,
 * OwningThread, LockSemaphore, SpinCount) -- confirm this remains valid for
 * the targeted Windows SDKs.
 */
static CRITICAL_SECTION g_mutex = {
   (PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0
};

/**
 * There is no way to invoke TlsSetValue for a different thread, so we
 * temporarily put the thread data for non-current threads here.
 *
 * Singly linked through stw_tls_data::next and guarded by g_mutex.  Nodes
 * are pushed by stw_tls_init(), unlinked by stw_tls_lookup_pending_data(),
 * and whatever is left is destroyed by stw_tls_cleanup().
 */
static struct stw_tls_data *g_pendingTlsData = NULL;


/* Forward declarations; both helpers are defined further down this file. */
static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId);

static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId);
59 
60 boolean
stw_tls_init(void)61 stw_tls_init(void)
62 {
63    tlsIndex = TlsAlloc();
64    if (tlsIndex == TLS_OUT_OF_INDEXES) {
65       return FALSE;
66    }
67 
68    /*
69     * DllMain is called with DLL_THREAD_ATTACH only for threads created after
70     * the DLL is loaded by the process.  So enumerate and add our hook to all
71     * previously existing threads.
72     *
73     * XXX: Except for the current thread since it there is an explicit
74     * stw_tls_init_thread() call for it later on.
75     */
76    if (1) {
77       DWORD dwCurrentProcessId = GetCurrentProcessId();
78       DWORD dwCurrentThreadId = GetCurrentThreadId();
79       HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, dwCurrentProcessId);
80       if (hSnapshot != INVALID_HANDLE_VALUE) {
81          THREADENTRY32 te;
82          te.dwSize = sizeof te;
83          if (Thread32First(hSnapshot, &te)) {
84             do {
85                if (te.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
86                                 sizeof te.th32OwnerProcessID) {
87                   if (te.th32OwnerProcessID == dwCurrentProcessId) {
88                      if (te.th32ThreadID != dwCurrentThreadId) {
89                         struct stw_tls_data *data;
90                         data = stw_tls_data_create(te.th32ThreadID);
91                         if (data) {
92                            EnterCriticalSection(&g_mutex);
93                            data->next = g_pendingTlsData;
94                            g_pendingTlsData = data;
95                            LeaveCriticalSection(&g_mutex);
96                         }
97                      }
98                   }
99                }
100                te.dwSize = sizeof te;
101             } while (Thread32Next(hSnapshot, &te));
102          }
103          CloseHandle(hSnapshot);
104       }
105    }
106 
107    return TRUE;
108 }
109 
110 
111 /**
112  * Install windows hook for a given thread (not necessarily the current one).
113  */
114 static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId)115 stw_tls_data_create(DWORD dwThreadId)
116 {
117    struct stw_tls_data *data;
118 
119    if (0) {
120       debug_printf("%s(0x%04lx)\n", __FUNCTION__, dwThreadId);
121    }
122 
123    data = calloc(1, sizeof *data);
124    if (!data) {
125       goto no_data;
126    }
127 
128    data->dwThreadId = dwThreadId;
129 
130    data->hCallWndProcHook = SetWindowsHookEx(WH_CALLWNDPROC,
131                                              stw_call_window_proc,
132                                              NULL,
133                                              dwThreadId);
134    if (data->hCallWndProcHook == NULL) {
135       goto no_hook;
136    }
137 
138    return data;
139 
140 no_hook:
141    free(data);
142 no_data:
143    return NULL;
144 }
145 
146 /**
147  * Destroy the per-thread data/hook.
148  *
149  * It is important to remove all hooks when unloading our DLL, otherwise our
150  * hook function might be called after it is no longer there.
151  */
152 static void
stw_tls_data_destroy(struct stw_tls_data * data)153 stw_tls_data_destroy(struct stw_tls_data *data)
154 {
155    assert(data);
156    if (!data) {
157       return;
158    }
159 
160    if (0) {
161       debug_printf("%s(0x%04lx)\n", __FUNCTION__, data->dwThreadId);
162    }
163 
164    if (data->hCallWndProcHook) {
165       UnhookWindowsHookEx(data->hCallWndProcHook);
166       data->hCallWndProcHook = NULL;
167    }
168 
169    free(data);
170 }
171 
172 boolean
stw_tls_init_thread(void)173 stw_tls_init_thread(void)
174 {
175    struct stw_tls_data *data;
176 
177    if (tlsIndex == TLS_OUT_OF_INDEXES) {
178       return FALSE;
179    }
180 
181    data = stw_tls_data_create(GetCurrentThreadId());
182    if (!data) {
183       return FALSE;
184    }
185 
186    TlsSetValue(tlsIndex, data);
187 
188    return TRUE;
189 }
190 
191 void
stw_tls_cleanup_thread(void)192 stw_tls_cleanup_thread(void)
193 {
194    struct stw_tls_data *data;
195 
196    if (tlsIndex == TLS_OUT_OF_INDEXES) {
197       return;
198    }
199 
200    data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
201    if (data) {
202       TlsSetValue(tlsIndex, NULL);
203    } else {
204       /* See if there this thread's data in on the pending list */
205       data = stw_tls_lookup_pending_data(GetCurrentThreadId());
206    }
207 
208    if (data) {
209       stw_tls_data_destroy(data);
210    }
211 }
212 
213 void
stw_tls_cleanup(void)214 stw_tls_cleanup(void)
215 {
216    if (tlsIndex != TLS_OUT_OF_INDEXES) {
217       /*
218        * Destroy all items in g_pendingTlsData linked list.
219        */
220       EnterCriticalSection(&g_mutex);
221       while (g_pendingTlsData) {
222          struct stw_tls_data * data = g_pendingTlsData;
223          g_pendingTlsData = data->next;
224          stw_tls_data_destroy(data);
225       }
226       LeaveCriticalSection(&g_mutex);
227 
228       TlsFree(tlsIndex);
229       tlsIndex = TLS_OUT_OF_INDEXES;
230    }
231 }
232 
233 /*
234  * Search for the current thread in the g_pendingTlsData linked list.
235  *
236  * It will remove and return the node on success, or return NULL on failure.
237  */
238 static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId)239 stw_tls_lookup_pending_data(DWORD dwThreadId)
240 {
241    struct stw_tls_data ** p_data;
242    struct stw_tls_data *data = NULL;
243 
244    EnterCriticalSection(&g_mutex);
245    for (p_data = &g_pendingTlsData; *p_data; p_data = &(*p_data)->next) {
246       if ((*p_data)->dwThreadId == dwThreadId) {
247          data = *p_data;
248 
249 	 /*
250 	  * Unlink the node.
251 	  */
252          *p_data = data->next;
253          data->next = NULL;
254 
255 	 break;
256       }
257    }
258    LeaveCriticalSection(&g_mutex);
259 
260    return data;
261 }
262 
263 struct stw_tls_data *
stw_tls_get_data(void)264 stw_tls_get_data(void)
265 {
266    struct stw_tls_data *data;
267 
268    if (tlsIndex == TLS_OUT_OF_INDEXES) {
269       return NULL;
270    }
271 
272    data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
273    if (!data) {
274       DWORD dwCurrentThreadId = GetCurrentThreadId();
275 
276       /*
277        * Search for the current thread in the g_pendingTlsData linked list.
278        */
279       data = stw_tls_lookup_pending_data(dwCurrentThreadId);
280 
281       if (!data) {
282          /*
283           * This should be impossible now.
284           */
285 	 assert(!"Failed to find thread data for thread id");
286 
287          /*
288           * DllMain is called with DLL_THREAD_ATTACH only by threads created
289           * after the DLL is loaded by the process
290           */
291          data = stw_tls_data_create(dwCurrentThreadId);
292          if (!data) {
293             return NULL;
294          }
295       }
296 
297       TlsSetValue(tlsIndex, data);
298    }
299 
300    assert(data);
301    assert(data->dwThreadId = GetCurrentThreadId());
302    assert(data->next == NULL);
303 
304    return data;
305 }
306