1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2  * vim: sw=2 ts=2 et lcs=trail\:.,tab\:>~ :
3  * This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "mozilla/ArrayUtils.h"
8 
9 #include "mozStorageSQLFunctions.h"
10 #include "nsTArray.h"
11 #include "nsUnicharUtils.h"
12 #include <algorithm>
13 
14 namespace mozilla {
15 namespace storage {
16 
17 ////////////////////////////////////////////////////////////////////////////////
18 //// Local Helper Functions
19 
20 namespace {
21 
22 /**
23  * Performs the LIKE comparison of a string against a pattern.  For more detail
24  * see http://www.sqlite.org/lang_expr.html#like.
25  *
26  * @param aPatternItr
27  *        An iterator at the start of the pattern to check for.
28  * @param aPatternEnd
29  *        An iterator at the end of the pattern to check for.
30  * @param aStringItr
31  *        An iterator at the start of the string to check for the pattern.
32  * @param aStringEnd
33  *        An iterator at the end of the string to check for the pattern.
34  * @param aEscapeChar
35  *        The character to use for escaping symbols in the pattern.
36  * @return 1 if the pattern is found, 0 otherwise.
37  */
likeCompare(nsAString::const_iterator aPatternItr,nsAString::const_iterator aPatternEnd,nsAString::const_iterator aStringItr,nsAString::const_iterator aStringEnd,char16_t aEscapeChar)38 int likeCompare(nsAString::const_iterator aPatternItr,
39                 nsAString::const_iterator aPatternEnd,
40                 nsAString::const_iterator aStringItr,
41                 nsAString::const_iterator aStringEnd, char16_t aEscapeChar) {
42   const char16_t MATCH_ALL('%');
43   const char16_t MATCH_ONE('_');
44 
45   bool lastWasEscape = false;
46   while (aPatternItr != aPatternEnd) {
47     /**
48      * What we do in here is take a look at each character from the input
49      * pattern, and do something with it.  There are 4 possibilities:
50      * 1) character is an un-escaped match-all character
51      * 2) character is an un-escaped match-one character
52      * 3) character is an un-escaped escape character
53      * 4) character is not any of the above
54      */
55     if (!lastWasEscape && *aPatternItr == MATCH_ALL) {
56       // CASE 1
57       /**
58        * Now we need to skip any MATCH_ALL or MATCH_ONE characters that follow a
59        * MATCH_ALL character.  For each MATCH_ONE character, skip one character
60        * in the pattern string.
61        */
62       while (*aPatternItr == MATCH_ALL || *aPatternItr == MATCH_ONE) {
63         if (*aPatternItr == MATCH_ONE) {
64           // If we've hit the end of the string we are testing, no match
65           if (aStringItr == aStringEnd) return 0;
66           aStringItr++;
67         }
68         aPatternItr++;
69       }
70 
71       // If we've hit the end of the pattern string, match
72       if (aPatternItr == aPatternEnd) return 1;
73 
74       while (aStringItr != aStringEnd) {
75         if (likeCompare(aPatternItr, aPatternEnd, aStringItr, aStringEnd,
76                         aEscapeChar)) {
77           // we've hit a match, so indicate this
78           return 1;
79         }
80         aStringItr++;
81       }
82 
83       // No match
84       return 0;
85     } else if (!lastWasEscape && *aPatternItr == MATCH_ONE) {
86       // CASE 2
87       if (aStringItr == aStringEnd) {
88         // If we've hit the end of the string we are testing, no match
89         return 0;
90       }
91       aStringItr++;
92       lastWasEscape = false;
93     } else if (!lastWasEscape && *aPatternItr == aEscapeChar) {
94       // CASE 3
95       lastWasEscape = true;
96     } else {
97       // CASE 4
98       if (::ToUpperCase(*aStringItr) != ::ToUpperCase(*aPatternItr)) {
99         // If we've hit a point where the strings don't match, there is no match
100         return 0;
101       }
102       aStringItr++;
103       lastWasEscape = false;
104     }
105 
106     aPatternItr++;
107   }
108 
109   return aStringItr == aStringEnd;
110 }
111 
112 /**
113  * Compute the Levenshtein Edit Distance between two strings.
114  *
115  * @param aStringS
116  *        a string
117  * @param aStringT
118  *        another string
119  * @param _result
120  *        an outparam that will receive the edit distance between the arguments
121  * @return a Sqlite result code, e.g. SQLITE_OK, SQLITE_NOMEM, etc.
122  */
levenshteinDistance(const nsAString & aStringS,const nsAString & aStringT,int * _result)123 int levenshteinDistance(const nsAString& aStringS, const nsAString& aStringT,
124                         int* _result) {
125   // Set the result to a non-sensical value in case we encounter an error.
126   *_result = -1;
127 
128   const uint32_t sLen = aStringS.Length();
129   const uint32_t tLen = aStringT.Length();
130 
131   if (sLen == 0) {
132     *_result = tLen;
133     return SQLITE_OK;
134   }
135   if (tLen == 0) {
136     *_result = sLen;
137     return SQLITE_OK;
138   }
139 
140   // Notionally, Levenshtein Distance is computed in a matrix.  If we
141   // assume s = "span" and t = "spam", the matrix would look like this:
142   //    s -->
143   //  t          s   p   a   n
144   //  |      0   1   2   3   4
145   //  V  s   1   *   *   *   *
146   //     p   2   *   *   *   *
147   //     a   3   *   *   *   *
148   //     m   4   *   *   *   *
149   //
150   // Note that the row width is sLen + 1 and the column height is tLen + 1,
151   // where sLen is the length of the string "s" and tLen is the length of "t".
152   // The first row and the first column are initialized as shown, and
153   // the algorithm computes the remaining cells row-by-row, and
154   // left-to-right within each row.  The computation only requires that
155   // we be able to see the current row and the previous one.
156 
157   // Allocate memory for two rows.
158   AutoTArray<int, nsAutoString::kStorageSize> row1;
159   AutoTArray<int, nsAutoString::kStorageSize> row2;
160 
161   // Declare the raw pointers that will actually be used to access the memory.
162   int* prevRow = row1.AppendElements(sLen + 1);
163   int* currRow = row2.AppendElements(sLen + 1);
164 
165   // Initialize the first row.
166   for (uint32_t i = 0; i <= sLen; i++) prevRow[i] = i;
167 
168   const char16_t* s = aStringS.BeginReading();
169   const char16_t* t = aStringT.BeginReading();
170 
171   // Compute the empty cells in the "matrix" row-by-row, starting with
172   // the second row.
173   for (uint32_t ti = 1; ti <= tLen; ti++) {
174     // Initialize the first cell in this row.
175     currRow[0] = ti;
176 
177     // Get the character from "t" that corresponds to this row.
178     const char16_t tch = t[ti - 1];
179 
180     // Compute the remaining cells in this row, left-to-right,
181     // starting at the second column (and first character of "s").
182     for (uint32_t si = 1; si <= sLen; si++) {
183       // Get the character from "s" that corresponds to this column,
184       // compare it to the t-character, and compute the "cost".
185       const char16_t sch = s[si - 1];
186       int cost = (sch == tch) ? 0 : 1;
187 
188       // ............ We want to calculate the value of cell "d" from
189       // ...ab....... the previously calculated (or initialized) cells
190       // ...cd....... "a", "b", and "c", where d = min(a', b', c').
191       // ............
192       int aPrime = prevRow[si - 1] + cost;
193       int bPrime = prevRow[si] + 1;
194       int cPrime = currRow[si - 1] + 1;
195       currRow[si] = std::min(aPrime, std::min(bPrime, cPrime));
196     }
197 
198     // Advance to the next row.  The current row becomes the previous
199     // row and we recycle the old previous row as the new current row.
200     // We don't need to re-initialize the new current row since we will
201     // rewrite all of its cells anyway.
202     int* oldPrevRow = prevRow;
203     prevRow = currRow;
204     currRow = oldPrevRow;
205   }
206 
207   // The final result is the value of the last cell in the last row.
208   // Note that that's now in the "previous" row, since we just swapped them.
209   *_result = prevRow[sLen];
210   return SQLITE_OK;
211 }
212 
// Describes one SQL function that registerFunctions below installs with
// sqlite3_create_function.  It lives at namespace scope rather than inside
// registerFunctions because ISO C++98 forbids instantiating a template (here,
// ArrayLength) with a locally-defined type.
//
// Field order matters: registerFunctions fills these in by aggregate
// initialization.
struct Functions {
  // Name the function is registered under in SQL.
  const char* zName;
  // Number of arguments this registration accepts.
  int nArg;
  // Text encoding passed to sqlite3_create_function (SQLITE_UTF8/SQLITE_UTF16).
  int enc;
  // User data passed to sqlite3_create_function; the implementation can read
  // it back with sqlite3_user_data() (caseFunction does).
  void* pContext;
  // The implementation callback.
  void (*xFunc)(::sqlite3_context*, int, sqlite3_value**);
};
222 
223 }  // namespace
224 
225 ////////////////////////////////////////////////////////////////////////////////
226 //// Exposed Functions
227 
registerFunctions(sqlite3 * aDB)228 int registerFunctions(sqlite3* aDB) {
229   Functions functions[] = {
230       {"lower", 1, SQLITE_UTF16, 0, caseFunction},
231       {"lower", 1, SQLITE_UTF8, 0, caseFunction},
232       {"upper", 1, SQLITE_UTF16, (void*)1, caseFunction},
233       {"upper", 1, SQLITE_UTF8, (void*)1, caseFunction},
234 
235       {"like", 2, SQLITE_UTF16, 0, likeFunction},
236       {"like", 2, SQLITE_UTF8, 0, likeFunction},
237       {"like", 3, SQLITE_UTF16, 0, likeFunction},
238       {"like", 3, SQLITE_UTF8, 0, likeFunction},
239 
240       {"levenshteinDistance", 2, SQLITE_UTF16, 0, levenshteinDistanceFunction},
241       {"levenshteinDistance", 2, SQLITE_UTF8, 0, levenshteinDistanceFunction},
242 
243       {"utf16Length", 1, SQLITE_UTF16, 0, utf16LengthFunction},
244       {"utf16Length", 1, SQLITE_UTF8, 0, utf16LengthFunction},
245   };
246 
247   int rv = SQLITE_OK;
248   for (size_t i = 0; SQLITE_OK == rv && i < ArrayLength(functions); ++i) {
249     struct Functions* p = &functions[i];
250     rv = ::sqlite3_create_function(aDB, p->zName, p->nArg, p->enc, p->pContext,
251                                    p->xFunc, nullptr, nullptr);
252   }
253 
254   return rv;
255 }
256 
257 ////////////////////////////////////////////////////////////////////////////////
258 //// SQL Functions
259 
caseFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)260 void caseFunction(sqlite3_context* aCtx, int aArgc, sqlite3_value** aArgv) {
261   NS_ASSERTION(1 == aArgc, "Invalid number of arguments!");
262 
263   const char16_t* value =
264       static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[0]));
265   nsAutoString data(value,
266                     ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t));
267   bool toUpper = ::sqlite3_user_data(aCtx) ? true : false;
268 
269   if (toUpper)
270     ::ToUpperCase(data);
271   else
272     ::ToLowerCase(data);
273 
274   // Set the result.
275   ::sqlite3_result_text16(aCtx, data.get(), data.Length() * sizeof(char16_t),
276                           SQLITE_TRANSIENT);
277 }
278 
279 /**
280  * This implements the like() SQL function.  This is used by the LIKE operator.
281  * The SQL statement 'A LIKE B' is implemented as 'like(B, A)', and if there is
282  * an escape character, say E, it is implemented as 'like(B, A, E)'.
283  */
likeFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)284 void likeFunction(sqlite3_context* aCtx, int aArgc, sqlite3_value** aArgv) {
285   NS_ASSERTION(2 == aArgc || 3 == aArgc, "Invalid number of arguments!");
286 
287   if (::sqlite3_value_bytes(aArgv[0]) > SQLITE_MAX_LIKE_PATTERN_LENGTH) {
288     ::sqlite3_result_error(aCtx, "LIKE or GLOB pattern too complex",
289                            SQLITE_TOOBIG);
290     return;
291   }
292 
293   if (!::sqlite3_value_text16(aArgv[0]) || !::sqlite3_value_text16(aArgv[1]))
294     return;
295 
296   const char16_t* a =
297       static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[1]));
298   int aLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(char16_t);
299   nsDependentString A(a, aLen);
300 
301   const char16_t* b =
302       static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[0]));
303   int bLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
304   nsDependentString B(b, bLen);
305   NS_ASSERTION(!B.IsEmpty(), "LIKE string must not be null!");
306 
307   char16_t E = 0;
308   if (3 == aArgc)
309     E = static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[2]))[0];
310 
311   nsAString::const_iterator itrString, endString;
312   A.BeginReading(itrString);
313   A.EndReading(endString);
314   nsAString::const_iterator itrPattern, endPattern;
315   B.BeginReading(itrPattern);
316   B.EndReading(endPattern);
317   ::sqlite3_result_int(
318       aCtx, likeCompare(itrPattern, endPattern, itrString, endString, E));
319 }
320 
levenshteinDistanceFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)321 void levenshteinDistanceFunction(sqlite3_context* aCtx, int aArgc,
322                                  sqlite3_value** aArgv) {
323   NS_ASSERTION(2 == aArgc, "Invalid number of arguments!");
324 
325   // If either argument is a SQL NULL, then return SQL NULL.
326   if (::sqlite3_value_type(aArgv[0]) == SQLITE_NULL ||
327       ::sqlite3_value_type(aArgv[1]) == SQLITE_NULL) {
328     ::sqlite3_result_null(aCtx);
329     return;
330   }
331 
332   const char16_t* a =
333       static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[0]));
334   int aLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
335 
336   const char16_t* b =
337       static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[1]));
338   int bLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(char16_t);
339 
340   // Compute the Levenshtein Distance, and return the result (or error).
341   int distance = -1;
342   const nsDependentString A(a, aLen);
343   const nsDependentString B(b, bLen);
344   int status = levenshteinDistance(A, B, &distance);
345   if (status == SQLITE_OK) {
346     ::sqlite3_result_int(aCtx, distance);
347   } else if (status == SQLITE_NOMEM) {
348     ::sqlite3_result_error_nomem(aCtx);
349   } else {
350     ::sqlite3_result_error(aCtx, "User function returned error code", -1);
351   }
352 }
353 
utf16LengthFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)354 void utf16LengthFunction(sqlite3_context* aCtx, int aArgc,
355                          sqlite3_value** aArgv) {
356   NS_ASSERTION(1 == aArgc, "Invalid number of arguments!");
357 
358   int len = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
359 
360   // Set the result.
361   ::sqlite3_result_int(aCtx, len);
362 }
363 
364 }  // namespace storage
365 }  // namespace mozilla
366