1 /*
2  * PROJECT:         ReactOS api tests
3  * LICENSE:         BSD - See COPYING.ARM in the top level directory
4  * PURPOSE:         Tests for IInitializeSpy
5  * PROGRAMMERS:     Mark Jansen
6  */
7 
8 #define WIN32_NO_STATUS
9 #define _INC_WINDOWS
10 #define COM_NO_WINDOWS_H
11 
12 #include <stdio.h>
13 #include <wine/test.h>
14 
15 #include <winuser.h>
16 #include <winreg.h>
17 
18 #include <shlwapi.h>
19 #include <unknownbase.h>
20 
21 #define test_S_OK(hres, message) ok((hres) == S_OK, "%s (0x%lx instead of S_OK)\n", (message), (hres))
22 #define test_HRES(hres, hresExpected, message) ok((hres) == (hresExpected), "%s (0x%lx instead of 0x%lx)\n", (message), (hres), (hresExpected))
23 #define test_ref(spy, expectedRef) ok((spy)->GetRef() == (expectedRef), "unexpected refcount, %ld instead of %d\n", (spy)->GetRef(), (expectedRef))
24 
25 
26 typedef HRESULT (WINAPI *pCoRegisterInitializeSpy_t)(_In_ LPINITIALIZESPY pSpy, _Out_ ULARGE_INTEGER *puliCookie);
27 typedef HRESULT (WINAPI *pCoRevokeInitializeSpy_t)(_In_ ULARGE_INTEGER uliCookie);
28 pCoRegisterInitializeSpy_t pCoRegisterInitializeSpy;
29 pCoRevokeInitializeSpy_t pCoRevokeInitializeSpy;
30 
31 
32 const DWORD INVALID_VALUE = 0xdeadbeef;
33 
34 
35 class CTestSpy : public CUnknownBase<IInitializeSpy>
36 {
37 public:
38     HRESULT hr;
39     ULARGE_INTEGER Cookie;
40 
41     // expected values to check against
42     HRESULT m_hrCoInit;
43     DWORD m_CoInit;
44     DWORD m_CurAptRefs;
45 
46     // keeping count of the times called
47     LONG m_PreInitCalled;
48     LONG m_PostInitCalled;
49     LONG m_PreUninitCalled;
50     LONG m_PostUninitCalled;
51 
52     // fake out some
53     bool m_FailQueryInterface;
54     bool m_AlwaysReturnOK;
55 
CTestSpy()56     CTestSpy()
57         : CUnknownBase( false, 0 ),
58         hr(0),
59         m_hrCoInit(0),
60         m_CoInit(0),
61         m_CurAptRefs(0),
62         m_FailQueryInterface(false),
63         m_AlwaysReturnOK(false)
64     {
65         Cookie.HighPart = Cookie.LowPart = INVALID_VALUE;
66         Clear();
67     }
68 
~CTestSpy()69     ~CTestSpy()
70     {
71         // always try to revoke if we succeeded to register.
72         if (SUCCEEDED(hr))
73         {
74             hr = pCoRevokeInitializeSpy(Cookie);
75             test_S_OK(hr, "CoRevokeInitializeSpy");
76         }
77         // we should be done.
78         ok(GetRef() == 0, "Expected m_lRef to be 0, was: %ld\n", GetRef());
79     }
80 
QueryInterface(REFIID riid,void ** ppv)81     HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv)
82     {
83         if (m_FailQueryInterface)
84         {
85             return E_NOINTERFACE;
86         }
87         return CUnknownBase::QueryInterface(riid, ppv);
88     }
89 
GetQITab()90     const QITAB* GetQITab()
91     {
92         static const QITAB tab[] = { { &IID_IInitializeSpy, OFFSETOFCLASS(IInitializeSpy, CTestSpy) }, { 0 } };
93         return tab;
94     }
95 
96 
PreInitialize(DWORD dwCoInit,DWORD dwCurThreadAptRefs)97     HRESULT STDMETHODCALLTYPE PreInitialize(DWORD dwCoInit, DWORD dwCurThreadAptRefs)
98     {
99         InterlockedIncrement(&m_PreInitCalled);
100         ok(m_CoInit == dwCoInit, "Unexpected dwCoInit: got %lx, expected %lx\n", dwCoInit, m_CoInit);
101         DWORD expectApt = m_hrCoInit == RPC_E_CHANGED_MODE ? m_CurAptRefs : m_CurAptRefs -1;
102         ok(expectApt == dwCurThreadAptRefs, "Unexpected dwCurThreadAptRefs: got %lx, expected %lx\n", dwCurThreadAptRefs, expectApt);
103         return S_OK;
104     }
105 
PostInitialize(HRESULT hrCoInit,DWORD dwCoInit,DWORD dwNewThreadAptRefs)106     HRESULT STDMETHODCALLTYPE PostInitialize(HRESULT hrCoInit, DWORD dwCoInit, DWORD dwNewThreadAptRefs)
107     {
108         InterlockedIncrement(&m_PostInitCalled);
109         ok(m_PreInitCalled == m_PostInitCalled, "Expected balanced pre/post: %ld / %ld\n", m_PreInitCalled, m_PostInitCalled);
110         test_HRES(hrCoInit, m_hrCoInit, "Unexpected hrCoInit in PostInitialize");
111         ok(m_CoInit == dwCoInit, "Unexpected dwCoInit: got %lx, expected %lx\n", dwCoInit, m_CoInit);
112         ok(m_CurAptRefs == dwNewThreadAptRefs, "Unexpected dwNewThreadAptRefs: got %lx, expected %lx\n", dwNewThreadAptRefs, m_CurAptRefs);
113         if (m_AlwaysReturnOK)
114             return S_OK;
115         return hrCoInit;
116     }
117 
PreUninitialize(DWORD dwCurThreadAptRefs)118     HRESULT STDMETHODCALLTYPE PreUninitialize(DWORD dwCurThreadAptRefs)
119     {
120         InterlockedIncrement(&m_PreUninitCalled);
121         ok(m_CurAptRefs == dwCurThreadAptRefs, "Unexpected dwCurThreadAptRefs: got %lx, expected %lx\n", dwCurThreadAptRefs, m_CurAptRefs);
122         return S_OK;
123     }
124 
PostUninitialize(DWORD dwNewThreadAptRefs)125     HRESULT STDMETHODCALLTYPE PostUninitialize(DWORD dwNewThreadAptRefs)
126     {
127         InterlockedIncrement(&m_PostUninitCalled);
128         ok(m_PreUninitCalled == m_PostUninitCalled, "Expected balanced pre/post: %ld / %ld\n", m_PreUninitCalled, m_PostUninitCalled);
129         DWORD apt = m_CurAptRefs ? (m_CurAptRefs-1) : 0;
130         ok(apt == dwNewThreadAptRefs, "Unexpected dwNewThreadAptRefs: got %lx, expected %lx\n", dwNewThreadAptRefs, apt);
131         return S_OK;
132     }
133 
Clear()134     void Clear()
135     {
136         m_PreInitCalled = 0;
137         m_PostInitCalled = 0;
138         m_PreUninitCalled = 0;
139         m_PostUninitCalled = 0;
140     }
141 
Expect(HRESULT hrCoInit,DWORD CoInit,DWORD CurAptRefs)142     void Expect(HRESULT hrCoInit, DWORD CoInit, DWORD CurAptRefs)
143     {
144         m_hrCoInit = hrCoInit;
145         m_CoInit = CoInit;
146         m_CurAptRefs = CurAptRefs;
147     }
148 
Check(LONG PreInit,LONG PostInit,LONG PreUninit,LONG PostUninit)149     void Check(LONG PreInit, LONG PostInit, LONG PreUninit, LONG PostUninit)
150     {
151         ok(m_PreInitCalled == PreInit, "Expected PreInit to be %ld, was: %ld\n", PreInit, m_PreInitCalled);
152         ok(m_PostInitCalled == PostInit, "Expected PostInit to be %ld, was: %ld\n", PostInit, m_PostInitCalled);
153         ok(m_PreUninitCalled == PreUninit, "Expected PreUninit to be %ld, was: %ld\n", PreUninit, m_PreUninitCalled);
154         ok(m_PostUninitCalled == PostUninit, "Expected PostUninit to be %ld, was: %ld\n", PostUninit, m_PostUninitCalled);
155     }
156 };
157 
158 
test_IInitializeSpy_register2()159 void test_IInitializeSpy_register2()
160 {
161     CTestSpy spy, spy2;
162 
163     // first we register 2 spies
164     spy.hr = pCoRegisterInitializeSpy(&spy, &spy.Cookie);
165     test_S_OK(spy.hr, "CoRegisterInitializeSpy");
166     test_ref(&spy, 1);
167 
168     spy2.hr = pCoRegisterInitializeSpy(&spy2, &spy2.Cookie);
169     test_S_OK(spy2.hr, "CoRegisterInitializeSpy");
170     test_ref(&spy, 1);
171 
172     // tell them what we expect
173     spy.Expect(S_OK, COINIT_APARTMENTTHREADED, 1);
174     spy2.Expect(S_OK, COINIT_APARTMENTTHREADED, 1);
175 
176     // Call CoInitializeEx and validate the results
177     HRESULT hr = CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
178     test_S_OK(hr, "CoInitializeEx");
179     spy.Check(1, 1, 0, 0);
180     spy2.Check(1, 1, 0, 0);
181 
182     // Calling CoInit twice with the same apartment makes it return S_FALSE but still increment count
183     spy.Expect(S_FALSE, COINIT_APARTMENTTHREADED, 2);
184     spy2.Expect(S_FALSE, COINIT_APARTMENTTHREADED, 2);
185 
186     hr = CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
187     test_HRES(hr, S_FALSE, "CoInitializeEx");
188     spy.Check(2, 2, 0, 0);
189     spy2.Check(2, 2, 0, 0);
190 
191     /* the order we registered the spies in is important here.
192         we have the second one to forcibly return S_OK, which makes the first spy see
193         S_OK instead of S_FALSE.. */
194     spy.Expect(S_OK, COINIT_APARTMENTTHREADED, 3);
195     spy2.m_AlwaysReturnOK = true;
196     spy2.Expect(S_FALSE, COINIT_APARTMENTTHREADED, 3);
197 
198     // and the S_OK also influences the returned value from CoInit.
199     hr = CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
200     test_S_OK(hr, "CoInitializeEx");
201     spy.Check(3, 3, 0, 0);
202     spy2.Check(3, 3, 0, 0);
203 
204     CoUninitialize();
205     spy.Check(3, 3, 1, 1);
206     spy2.Check(3, 3, 1, 1);
207 
208     spy.m_CurAptRefs = spy2.m_CurAptRefs = 2;
209 
210     CoUninitialize();
211     spy.Check(3, 3, 2, 2);
212     spy2.Check(3, 3, 2, 2);
213 
214     spy.m_CurAptRefs = spy2.m_CurAptRefs = 1;
215 
216     CoUninitialize();
217     spy.Check(3, 3, 3, 3);
218     spy2.Check(3, 3, 3, 3);
219 
220     spy.m_CurAptRefs = spy2.m_CurAptRefs = 0;
221 
222     CoUninitialize();
223     spy.Check(3, 3, 4, 4);
224     spy2.Check(3, 3, 4, 4);
225 }
226 
test_IInitializeSpy_switch_apt()227 void test_IInitializeSpy_switch_apt()
228 {
229     CTestSpy spy;
230 
231     spy.hr = pCoRegisterInitializeSpy(&spy, &spy.Cookie);
232     test_S_OK(spy.hr, "CoRegisterInitializeSpy");
233     test_ref(&spy, 1);
234 
235     spy.Expect(S_OK, COINIT_APARTMENTTHREADED, 1);
236 
237     HRESULT hr = CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
238     test_S_OK(hr, "CoInitializeEx");
239     spy.Check(1, 1, 0, 0);
240 
241     spy.Expect(RPC_E_CHANGED_MODE, COINIT_MULTITHREADED, 1);
242 
243     hr = CoInitializeEx(NULL, COINIT_MULTITHREADED);
244     test_HRES(hr, RPC_E_CHANGED_MODE, "CoInitializeEx");
245     spy.Check(2, 2, 0, 0);
246 
247 
248     CoUninitialize();
249     spy.Check(2, 2, 1, 1);
250 
251     spy.m_CurAptRefs = 0;
252 
253     CoUninitialize();
254     spy.Check(2, 2, 2, 2);
255 
256     CoUninitialize();
257     spy.Check(2, 2, 3, 3);
258 }
259 
test_IInitializeSpy_fail()260 void test_IInitializeSpy_fail()
261 {
262     CTestSpy spy;
263 
264     spy.m_FailQueryInterface = true;
265 
266     spy.hr = pCoRegisterInitializeSpy(&spy, &spy.Cookie);
267     test_HRES(spy.hr, E_NOINTERFACE, "Unexpected hr while registering invalid interface");
268     test_ref(&spy, 0);
269     ok(spy.Cookie.HighPart == 0xffffffff, "Unexpected Cookie.HighPart, expected 0xffffffff got: 0x%08lx\n", spy.Cookie.HighPart);
270     ok(spy.Cookie.LowPart == 0xffffffff, "Unexpected Cookie.HighPart, expected 0xffffffff got: 0x%08lx\n", spy.Cookie.LowPart);
271 
272     spy.Cookie.HighPart = spy.Cookie.LowPart = 0xffffffff;
273     HRESULT hr = pCoRevokeInitializeSpy(spy.Cookie);
274     test_HRES(hr, E_INVALIDARG, "Unexpected hr while unregistering invalid interface");
275     test_ref(&spy, 0);
276 
277     spy.Cookie.HighPart = spy.Cookie.LowPart = 0;
278     hr = pCoRevokeInitializeSpy(spy.Cookie);
279     test_HRES(hr, E_INVALIDARG, "Unexpected hr while unregistering invalid interface");
280     test_ref(&spy, 0);
281 
282     /* we should not crash here, just return E_NOINTERFACE
283         do note the Cookie is not even being touched at all, compared to calling this with an interface
284         that does not respond to IID_IInitializeSpy */
285     spy.Cookie.HighPart = spy.Cookie.LowPart = INVALID_VALUE;
286     hr = pCoRegisterInitializeSpy(NULL, &spy.Cookie);
287     test_HRES(spy.hr, E_NOINTERFACE, "Unexpected hr while registering NULL interface");
288     ok(spy.Cookie.HighPart == INVALID_VALUE, "Unexpected Cookie.HighPart, expected 0xdeadbeef got: %lx\n", spy.Cookie.HighPart);
289     ok(spy.Cookie.LowPart == INVALID_VALUE, "Unexpected Cookie.HighPart, expected 0xdeadbeef got: %lx\n", spy.Cookie.LowPart);
290 }
291 
test_IInitializeSpy_twice()292 void test_IInitializeSpy_twice()
293 {
294     CTestSpy spy;
295 
296     spy.hr = pCoRegisterInitializeSpy(&spy, &spy.Cookie);
297     test_S_OK(spy.hr, "CoRegisterInitializeSpy");
298     test_ref(&spy, 1);
299 
300     ULARGE_INTEGER Cookie = { { INVALID_VALUE, INVALID_VALUE } };
301     HRESULT hr = pCoRegisterInitializeSpy(&spy, &Cookie);
302     test_S_OK(hr, "CoRegisterInitializeSpy");
303     test_ref(&spy, 2);
304 
305     hr = pCoRevokeInitializeSpy(Cookie);
306     test_S_OK(hr, "CoRevokeInitializeSpy");
307     test_ref(&spy, 1);
308 }
309 
310 
/*
 * Test entry point: resolves the spy registration API from ole32.dll and runs
 * all test cases. The tests are skipped when the API cannot be resolved,
 * instead of calling through a NULL function pointer.
 */
START_TEST(initializespy)
{
    HMODULE ole32 = LoadLibraryA("ole32.dll");
    if (!ole32)
    {
        skip("Unable to load ole32.dll\n");
        return;
    }

    pCoRegisterInitializeSpy = (pCoRegisterInitializeSpy_t)GetProcAddress(ole32, "CoRegisterInitializeSpy");
    pCoRevokeInitializeSpy = (pCoRevokeInitializeSpy_t)GetProcAddress(ole32, "CoRevokeInitializeSpy");
    if (!pCoRegisterInitializeSpy || !pCoRevokeInitializeSpy)
    {
        skip("CoRegisterInitializeSpy / CoRevokeInitializeSpy not available\n");
        return;
    }

    test_IInitializeSpy_register2();
    test_IInitializeSpy_switch_apt();
    test_IInitializeSpy_fail();
    test_IInitializeSpy_twice();
}
322