1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2  * vim: sw=2 ts=2 et lcs=trail\:.,tab\:>~ :
3  * This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #ifndef storage_test_harness_h__
8 #define storage_test_harness_h__
9 
10 #include "gtest/gtest.h"
11 
12 #include "prthread.h"
13 #include "nsAppDirectoryServiceDefs.h"
14 #include "nsDirectoryServiceDefs.h"
15 #include "nsDirectoryServiceUtils.h"
16 #include "nsMemory.h"
17 #include "nsServiceManagerUtils.h"
18 #include "nsThreadUtils.h"
19 #include "mozilla/ReentrantMonitor.h"
20 
21 #include "mozIStorageService.h"
22 #include "mozIStorageConnection.h"
23 #include "mozIStorageStatementCallback.h"
24 #include "mozIStorageCompletionCallback.h"
25 #include "mozIStorageBindingParamsArray.h"
26 #include "mozIStorageBindingParams.h"
27 #include "mozIStorageAsyncStatement.h"
28 #include "mozIStorageStatement.h"
29 #include "mozIStoragePendingStatement.h"
30 #include "mozIStorageError.h"
31 #include "nsIInterfaceRequestorUtils.h"
32 #include "nsIEventTarget.h"
33 
34 #include "sqlite3.h"
35 
/*
 * Legacy-named assertion helpers mapped onto gtest's EXPECT_* macros.  A
 * failing check records a (non-fatal) test failure; execution continues.
 */
#define do_check_true(aCondition) \
  EXPECT_TRUE(aCondition)

#define do_check_false(aCondition) \
  EXPECT_FALSE(aCondition)

// Passes when the given nsresult indicates success.
#define do_check_success(aResult) \
  do_check_true(NS_SUCCEEDED(aResult))

// Both operands are parenthesized so that arguments built from operators that
// bind looser than == (e.g. `&`, `|`, `^`) are compared as whole expressions.
#define do_check_eq(aExpected, aActual) \
  do_check_true((aExpected) == (aActual))

// Passes when a raw SQLite call returns SQLITE_OK.
#define do_check_ok(aInvoc) \
  do_check_true((aInvoc) == SQLITE_OK)
50 
51 already_AddRefed<mozIStorageService>
getService()52 getService()
53 {
54   nsCOMPtr<mozIStorageService> ss =
55     do_CreateInstance("@mozilla.org/storage/service;1");
56   do_check_true(ss);
57   return ss.forget();
58 }
59 
60 already_AddRefed<mozIStorageConnection>
getMemoryDatabase()61 getMemoryDatabase()
62 {
63   nsCOMPtr<mozIStorageService> ss = getService();
64   nsCOMPtr<mozIStorageConnection> conn;
65   nsresult rv = ss->OpenSpecialDatabase("memory", getter_AddRefs(conn));
66   do_check_success(rv);
67   return conn.forget();
68 }
69 
70 already_AddRefed<mozIStorageConnection>
getDatabase()71 getDatabase()
72 {
73   nsCOMPtr<nsIFile> dbFile;
74   (void)NS_GetSpecialDirectory(NS_APP_USER_PROFILE_50_DIR,
75                                getter_AddRefs(dbFile));
76   NS_ASSERTION(dbFile, "The directory doesn't exists?!");
77 
78   nsresult rv = dbFile->Append(NS_LITERAL_STRING("storage_test_db.sqlite"));
79   do_check_success(rv);
80 
81   nsCOMPtr<mozIStorageService> ss = getService();
82   nsCOMPtr<mozIStorageConnection> conn;
83   rv = ss->OpenDatabase(dbFile, getter_AddRefs(conn));
84   do_check_success(rv);
85   return conn.forget();
86 }
87 
88 
/**
 * Callback object that lets a test block until an async storage operation
 * finishes.  It is used both as a statement callback (blocking_async_execute
 * passes it to ExecuteAsync) and as a completion callback (blocking_async_close
 * passes it to AsyncClose).  HandleCompletion or Complete sets mCompleted,
 * which SpinUntilCompleted() waits for by pumping the current thread's events.
 */
class AsyncStatementSpinner : public mozIStorageStatementCallback
                            , public mozIStorageCompletionCallback
{
public:
  NS_DECL_ISUPPORTS
  NS_DECL_MOZISTORAGESTATEMENTCALLBACK
  NS_DECL_MOZISTORAGECOMPLETIONCALLBACK

  AsyncStatementSpinner();

  // Process events on the current thread until a completion notification has
  // arrived or event processing reports a failure.
  void SpinUntilCompleted();

  // Reason passed to the most recent HandleCompletion call; 0 until then.
  uint16_t completionReason;

protected:
  // Protected: the object is refcounted and is destroyed via Release().
  virtual ~AsyncStatementSpinner() {}
  // Set by HandleCompletion/Complete, polled by SpinUntilCompleted.
  // NOTE(review): `volatile` is not a synchronization primitive; this looks
  // correct only if the callbacks run on the spinning thread — confirm.
  volatile bool mCompleted;
};
107 
// QueryInterface/AddRef/Release for both callback interfaces the spinner
// implements.
NS_IMPL_ISUPPORTS(AsyncStatementSpinner,
                  mozIStorageStatementCallback,
                  mozIStorageCompletionCallback)
111 
112 AsyncStatementSpinner::AsyncStatementSpinner()
113 : completionReason(0)
114 , mCompleted(false)
115 {
116 }
117 
118 NS_IMETHODIMP
HandleResult(mozIStorageResultSet * aResultSet)119 AsyncStatementSpinner::HandleResult(mozIStorageResultSet *aResultSet)
120 {
121   return NS_OK;
122 }
123 
124 NS_IMETHODIMP
HandleError(mozIStorageError * aError)125 AsyncStatementSpinner::HandleError(mozIStorageError *aError)
126 {
127   int32_t result;
128   nsresult rv = aError->GetResult(&result);
129   NS_ENSURE_SUCCESS(rv, rv);
130   nsAutoCString message;
131   rv = aError->GetMessage(message);
132   NS_ENSURE_SUCCESS(rv, rv);
133 
134   nsAutoCString warnMsg;
135   warnMsg.AppendLiteral("An error occurred while executing an async statement: ");
136   warnMsg.AppendInt(result);
137   warnMsg.Append(' ');
138   warnMsg.Append(message);
139   NS_WARNING(warnMsg.get());
140 
141   return NS_OK;
142 }
143 
144 NS_IMETHODIMP
HandleCompletion(uint16_t aReason)145 AsyncStatementSpinner::HandleCompletion(uint16_t aReason)
146 {
147   completionReason = aReason;
148   mCompleted = true;
149   return NS_OK;
150 }
151 
152 NS_IMETHODIMP
Complete(nsresult,nsISupports *)153 AsyncStatementSpinner::Complete(nsresult, nsISupports*)
154 {
155   mCompleted = true;
156   return NS_OK;
157 }
158 
SpinUntilCompleted()159 void AsyncStatementSpinner::SpinUntilCompleted()
160 {
161   nsCOMPtr<nsIThread> thread(::do_GetCurrentThread());
162   nsresult rv = NS_OK;
163   bool processed = true;
164   while (!mCompleted && NS_SUCCEEDED(rv)) {
165     rv = thread->ProcessNextEvent(true, &processed);
166   }
167 }
168 
// Declares an overriding HandleResult.  Presumably intended for subclasses of
// AsyncStatementSpinner that need to inspect result rows (no user is visible
// in this file).
#define NS_DECL_ASYNCSTATEMENTSPINNER \
  NS_IMETHOD HandleResult(mozIStorageResultSet *aResultSet) override;
171 
172 ////////////////////////////////////////////////////////////////////////////////
173 //// Async Helpers
174 
175 /**
176  * Execute an async statement, blocking the main thread until we get the
177  * callback completion notification.
178  */
179 void
blocking_async_execute(mozIStorageBaseStatement * stmt)180 blocking_async_execute(mozIStorageBaseStatement *stmt)
181 {
182   RefPtr<AsyncStatementSpinner> spinner(new AsyncStatementSpinner());
183 
184   nsCOMPtr<mozIStoragePendingStatement> pendy;
185   (void)stmt->ExecuteAsync(spinner, getter_AddRefs(pendy));
186   spinner->SpinUntilCompleted();
187 }
188 
189 /**
190  * Invoke AsyncClose on the given connection, blocking the main thread until we
191  * get the completion notification.
192  */
193 void
blocking_async_close(mozIStorageConnection * db)194 blocking_async_close(mozIStorageConnection *db)
195 {
196   RefPtr<AsyncStatementSpinner> spinner(new AsyncStatementSpinner());
197 
198   db->AsyncClose(spinner);
199   spinner->SpinUntilCompleted();
200 }
201 
202 ////////////////////////////////////////////////////////////////////////////////
203 //// Mutex Watching
204 
205 /**
206  * Verify that mozIStorageAsyncStatement's life-cycle never triggers a mutex on
207  * the caller (generally main) thread.  We do this by decorating the sqlite
208  * mutex logic with our own code that checks what thread it is being invoked on
209  * and sets a flag if it is invoked on the main thread.  We are able to easily
210  * decorate the SQLite mutex logic because SQLite allows us to retrieve the
211  * current function pointers being used and then provide a new set.
212  */
213 
// SQLite's default mutex methods, captured by HookSqliteMutex.  The wrappers
// forward to them, and ~HookSqliteMutex reinstalls them.
sqlite3_mutex_methods orig_mutex_methods;
// Copy of the defaults with xMutexEnter/xMutexTry replaced by the wrapped_*
// functions; installed by HookSqliteMutex's constructor.
sqlite3_mutex_methods wrapped_mutex_methods;

// Set when a wrapped mutex function runs on watched_thread; cleared by
// watch_for_mutex_use_on_this_thread().
bool mutex_used_on_watched_thread = false;
// The thread being watched; null until watch_for_mutex_use_on_this_thread().
PRThread *watched_thread = nullptr;
/**
 * Ugly hack to let us figure out what a connection's async thread is.  If we
 * were MOZILLA_INTERNAL_API and linked as such we could just include
 * mozStorageConnection.h and just ask Connection directly.  But that turns out
 * poorly.
 *
 * When the thread a mutex is invoked on isn't watched_thread we save it to this
 * variable (only wrapped_MutexEnter does this).
 *
 * NOTE(review): these globals are written from whichever thread takes a
 * SQLite mutex and read from the test thread without synchronization.  They
 * are also non-inline definitions in a header, so the header can only be
 * included in one translation unit — confirm both are acceptable.
 */
PRThread *last_non_watched_thread = nullptr;
229 
230 /**
231  * Set a flag if the mutex is used on the thread we are watching, but always
232  * call the real mutex function.
233  */
wrapped_MutexEnter(sqlite3_mutex * mutex)234 extern "C" void wrapped_MutexEnter(sqlite3_mutex *mutex)
235 {
236   PRThread *curThread = ::PR_GetCurrentThread();
237   if (curThread == watched_thread)
238     mutex_used_on_watched_thread = true;
239   else
240     last_non_watched_thread = curThread;
241   orig_mutex_methods.xMutexEnter(mutex);
242 }
243 
wrapped_MutexTry(sqlite3_mutex * mutex)244 extern "C" int wrapped_MutexTry(sqlite3_mutex *mutex)
245 {
246   if (::PR_GetCurrentThread() == watched_thread)
247     mutex_used_on_watched_thread = true;
248   return orig_mutex_methods.xMutexTry(mutex);
249 }
250 
/**
 * Scoped guard: the constructor installs wrapped_mutex_methods (so
 * wrapped_MutexEnter/wrapped_MutexTry see SQLite's mutex enter/try calls) and
 * the destructor reinstalls SQLite's original mutex methods.
 *
 * NOTE(review): both constructor and destructor call sqlite3_shutdown(), which
 * per SQLite's documentation expects no open connections.  This presumably must
 * bracket a test's whole connection lifetime — confirm.
 */
class HookSqliteMutex
{
public:
  HookSqliteMutex()
  {
    // We need to initialize and teardown SQLite to get it to set up the
    // default mutex handlers for us so we can steal them and wrap them.
    do_check_ok(sqlite3_initialize());
    do_check_ok(sqlite3_shutdown());
    // Keep the defaults to forward to and to restore later...
    do_check_ok(::sqlite3_config(SQLITE_CONFIG_GETMUTEX, &orig_mutex_methods));
    // ...and seed the wrapped table with the same defaults, overriding only
    // the enter/try entry points before installing it.
    do_check_ok(::sqlite3_config(SQLITE_CONFIG_GETMUTEX, &wrapped_mutex_methods));
    wrapped_mutex_methods.xMutexEnter = wrapped_MutexEnter;
    wrapped_mutex_methods.xMutexTry = wrapped_MutexTry;
    do_check_ok(::sqlite3_config(SQLITE_CONFIG_MUTEX, &wrapped_mutex_methods));
  }

  ~HookSqliteMutex()
  {
    // sqlite3_config may only be used while SQLite is shut down, so shut down,
    // restore the original mutex methods, then initialize again.
    do_check_ok(sqlite3_shutdown());
    do_check_ok(::sqlite3_config(SQLITE_CONFIG_MUTEX, &orig_mutex_methods));
    do_check_ok(sqlite3_initialize());
  }
};
274 
275 /**
276  * Call to clear the watch state and to set the watching against this thread.
277  *
278  * Check |mutex_used_on_watched_thread| to see if the mutex has fired since
279  * this method was last called.  Since we're talking about the current thread,
280  * there are no race issues to be concerned about
281  */
watch_for_mutex_use_on_this_thread()282 void watch_for_mutex_use_on_this_thread()
283 {
284   watched_thread = ::PR_GetCurrentThread();
285   mutex_used_on_watched_thread = false;
286 }
287 
288 
289 ////////////////////////////////////////////////////////////////////////////////
290 //// Thread Wedgers
291 
292 /**
293  * A runnable that blocks until code on another thread invokes its unwedge
294  * method.  By dispatching this to a thread you can ensure that no subsequent
295  * runnables dispatched to the thread will execute until you invoke unwedge.
296  *
297  * The wedger is self-dispatching, just construct it with its target.
298  */
299 class ThreadWedger : public mozilla::Runnable
300 {
301 public:
ThreadWedger(nsIEventTarget * aTarget)302   explicit ThreadWedger(nsIEventTarget* aTarget)
303     : mozilla::Runnable("ThreadWedger")
304     , mReentrantMonitor("thread wedger")
305     , unwedged(false)
306   {
307     aTarget->Dispatch(this, aTarget->NS_DISPATCH_NORMAL);
308   }
309 
Run()310   NS_IMETHOD Run() override
311   {
312     mozilla::ReentrantMonitorAutoEnter automon(mReentrantMonitor);
313 
314     if (!unwedged)
315       automon.Wait();
316 
317     return NS_OK;
318   }
319 
unwedge()320   void unwedge()
321   {
322     mozilla::ReentrantMonitorAutoEnter automon(mReentrantMonitor);
323     unwedged = true;
324     automon.Notify();
325   }
326 
327 private:
328   mozilla::ReentrantMonitor mReentrantMonitor;
329   bool unwedged;
330 };
331 
332 ////////////////////////////////////////////////////////////////////////////////
333 //// Async Helpers
334 
335 /**
336  * A horrible hack to figure out what the connection's async thread is.  By
337  * creating a statement and async dispatching we can tell from the mutex who
338  * is the async thread, PRThread style.  Then we map that to an nsIThread.
339  */
340 already_AddRefed<nsIThread>
get_conn_async_thread(mozIStorageConnection * db)341 get_conn_async_thread(mozIStorageConnection *db)
342 {
343   // Make sure we are tracking the current thread as the watched thread
344   watch_for_mutex_use_on_this_thread();
345 
346   // - statement with nothing to bind
347   nsCOMPtr<mozIStorageAsyncStatement> stmt;
348   db->CreateAsyncStatement(
349     NS_LITERAL_CSTRING("SELECT 1"),
350     getter_AddRefs(stmt));
351   blocking_async_execute(stmt);
352   stmt->Finalize();
353 
354   nsCOMPtr<nsIThreadManager> threadMan =
355     do_GetService("@mozilla.org/thread-manager;1");
356   nsCOMPtr<nsIThread> asyncThread;
357   threadMan->GetThreadFromPRThread(last_non_watched_thread,
358                                    getter_AddRefs(asyncThread));
359 
360   // Additionally, check that the thread we get as the background thread is the
361   // same one as the one we report from getInterface.
362   nsCOMPtr<nsIEventTarget> target = do_GetInterface(db);
363   nsCOMPtr<nsIThread> allegedAsyncThread = do_QueryInterface(target);
364   PRThread *allegedPRThread;
365   (void)allegedAsyncThread->GetPRThread(&allegedPRThread);
366   do_check_eq(allegedPRThread, last_non_watched_thread);
367   return asyncThread.forget();
368 }
369 
370 #endif // storage_test_harness_h__
371 
372