/************************************************************************** * * Copyright 2009-2013 VMware, Inc. * All Rights Reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sub license, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice (including the * next paragraph) shall be included in all copies or substantial portions * of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * **************************************************************************/ #include #include #include "pipe/p_compiler.h" #include "util/u_debug.h" #include "stw_tls.h" static DWORD tlsIndex = TLS_OUT_OF_INDEXES; /** * Static mutex to protect the access to g_pendingTlsData global and * stw_tls_data::next member. */ static CRITICAL_SECTION g_mutex = { (PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0 }; /** * There is no way to invoke TlsSetValue for a different thread, so we * temporarily put the thread data for non-current threads here. */ static struct stw_tls_data *g_pendingTlsData = NULL; static inline struct stw_tls_data * stw_tls_data_create(DWORD dwThreadId); static struct stw_tls_data * stw_tls_lookup_pending_data(DWORD dwThreadId); boolean stw_tls_init(void) { tlsIndex = TlsAlloc(); if (tlsIndex == TLS_OUT_OF_INDEXES) { return FALSE; } /* * DllMain is called with DLL_THREAD_ATTACH only for threads created after * the DLL is loaded by the process. So enumerate and add our hook to all * previously existing threads. * * XXX: Except for the current thread since it there is an explicit * stw_tls_init_thread() call for it later on. */ if (1) { DWORD dwCurrentProcessId = GetCurrentProcessId(); DWORD dwCurrentThreadId = GetCurrentThreadId(); HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, dwCurrentProcessId); if (hSnapshot != INVALID_HANDLE_VALUE) { THREADENTRY32 te; te.dwSize = sizeof te; if (Thread32First(hSnapshot, &te)) { do { if (te.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) + sizeof te.th32OwnerProcessID) { if (te.th32OwnerProcessID == dwCurrentProcessId) { if (te.th32ThreadID != dwCurrentThreadId) { struct stw_tls_data *data; data = stw_tls_data_create(te.th32ThreadID); if (data) { EnterCriticalSection(&g_mutex); data->next = g_pendingTlsData; g_pendingTlsData = data; LeaveCriticalSection(&g_mutex); } } } } te.dwSize = sizeof te; } while (Thread32Next(hSnapshot, &te)); } CloseHandle(hSnapshot); } } return TRUE; } /** * Install windows hook for a given thread (not necessarily the current one). */ static inline struct stw_tls_data * stw_tls_data_create(DWORD dwThreadId) { struct stw_tls_data *data; if (0) { debug_printf("%s(0x%04lx)\n", __FUNCTION__, dwThreadId); } data = calloc(1, sizeof *data); if (!data) { goto no_data; } data->dwThreadId = dwThreadId; data->hCallWndProcHook = SetWindowsHookEx(WH_CALLWNDPROC, stw_call_window_proc, NULL, dwThreadId); if (data->hCallWndProcHook == NULL) { goto no_hook; } return data; no_hook: free(data); no_data: return NULL; } /** * Destroy the per-thread data/hook. * * It is important to remove all hooks when unloading our DLL, otherwise our * hook function might be called after it is no longer there. */ static void stw_tls_data_destroy(struct stw_tls_data *data) { assert(data); if (!data) { return; } if (0) { debug_printf("%s(0x%04lx)\n", __FUNCTION__, data->dwThreadId); } if (data->hCallWndProcHook) { UnhookWindowsHookEx(data->hCallWndProcHook); data->hCallWndProcHook = NULL; } free(data); } boolean stw_tls_init_thread(void) { struct stw_tls_data *data; if (tlsIndex == TLS_OUT_OF_INDEXES) { return FALSE; } data = stw_tls_data_create(GetCurrentThreadId()); if (!data) { return FALSE; } TlsSetValue(tlsIndex, data); return TRUE; } void stw_tls_cleanup_thread(void) { struct stw_tls_data *data; if (tlsIndex == TLS_OUT_OF_INDEXES) { return; } data = (struct stw_tls_data *) TlsGetValue(tlsIndex); if (data) { TlsSetValue(tlsIndex, NULL); } else { /* See if there this thread's data in on the pending list */ data = stw_tls_lookup_pending_data(GetCurrentThreadId()); } if (data) { stw_tls_data_destroy(data); } } void stw_tls_cleanup(void) { if (tlsIndex != TLS_OUT_OF_INDEXES) { /* * Destroy all items in g_pendingTlsData linked list. */ EnterCriticalSection(&g_mutex); while (g_pendingTlsData) { struct stw_tls_data * data = g_pendingTlsData; g_pendingTlsData = data->next; stw_tls_data_destroy(data); } LeaveCriticalSection(&g_mutex); TlsFree(tlsIndex); tlsIndex = TLS_OUT_OF_INDEXES; } } /* * Search for the current thread in the g_pendingTlsData linked list. * * It will remove and return the node on success, or return NULL on failure. */ static struct stw_tls_data * stw_tls_lookup_pending_data(DWORD dwThreadId) { struct stw_tls_data ** p_data; struct stw_tls_data *data = NULL; EnterCriticalSection(&g_mutex); for (p_data = &g_pendingTlsData; *p_data; p_data = &(*p_data)->next) { if ((*p_data)->dwThreadId == dwThreadId) { data = *p_data; /* * Unlink the node. */ *p_data = data->next; data->next = NULL; break; } } LeaveCriticalSection(&g_mutex); return data; } struct stw_tls_data * stw_tls_get_data(void) { struct stw_tls_data *data; if (tlsIndex == TLS_OUT_OF_INDEXES) { return NULL; } data = (struct stw_tls_data *) TlsGetValue(tlsIndex); if (!data) { DWORD dwCurrentThreadId = GetCurrentThreadId(); /* * Search for the current thread in the g_pendingTlsData linked list. */ data = stw_tls_lookup_pending_data(dwCurrentThreadId); if (!data) { /* * This should be impossible now. */ assert(!"Failed to find thread data for thread id"); /* * DllMain is called with DLL_THREAD_ATTACH only by threads created * after the DLL is loaded by the process */ data = stw_tls_data_create(dwCurrentThreadId); if (!data) { return NULL; } } TlsSetValue(tlsIndex, data); } assert(data); assert(data->dwThreadId = GetCurrentThreadId()); assert(data->next == NULL); return data; }