1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2  * vim: sw=2 ts=2 et lcs=trail\:.,tab\:>~ :
3  * This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "mozilla/ArrayUtils.h"
8 
9 #include "mozStorageSQLFunctions.h"
10 #include "nsUnicharUtils.h"
11 #include <algorithm>
12 
13 namespace mozilla {
14 namespace storage {
15 
16 ////////////////////////////////////////////////////////////////////////////////
17 //// Local Helper Functions
18 
19 namespace {
20 
21 /**
22  * Performs the LIKE comparison of a string against a pattern.  For more detail
23  * see http://www.sqlite.org/lang_expr.html#like.
24  *
25  * @param aPatternItr
26  *        An iterator at the start of the pattern to check for.
27  * @param aPatternEnd
28  *        An iterator at the end of the pattern to check for.
29  * @param aStringItr
30  *        An iterator at the start of the string to check for the pattern.
31  * @param aStringEnd
32  *        An iterator at the end of the string to check for the pattern.
33  * @param aEscapeChar
34  *        The character to use for escaping symbols in the pattern.
35  * @return 1 if the pattern is found, 0 otherwise.
36  */
likeCompare(nsAString::const_iterator aPatternItr,nsAString::const_iterator aPatternEnd,nsAString::const_iterator aStringItr,nsAString::const_iterator aStringEnd,char16_t aEscapeChar)37 int likeCompare(nsAString::const_iterator aPatternItr,
38                 nsAString::const_iterator aPatternEnd,
39                 nsAString::const_iterator aStringItr,
40                 nsAString::const_iterator aStringEnd, char16_t aEscapeChar) {
41   const char16_t MATCH_ALL('%');
42   const char16_t MATCH_ONE('_');
43 
44   bool lastWasEscape = false;
45   while (aPatternItr != aPatternEnd) {
46     /**
47      * What we do in here is take a look at each character from the input
48      * pattern, and do something with it.  There are 4 possibilities:
49      * 1) character is an un-escaped match-all character
50      * 2) character is an un-escaped match-one character
51      * 3) character is an un-escaped escape character
52      * 4) character is not any of the above
53      */
54     if (!lastWasEscape && *aPatternItr == MATCH_ALL) {
55       // CASE 1
56       /**
57        * Now we need to skip any MATCH_ALL or MATCH_ONE characters that follow a
58        * MATCH_ALL character.  For each MATCH_ONE character, skip one character
59        * in the pattern string.
60        */
61       while (*aPatternItr == MATCH_ALL || *aPatternItr == MATCH_ONE) {
62         if (*aPatternItr == MATCH_ONE) {
63           // If we've hit the end of the string we are testing, no match
64           if (aStringItr == aStringEnd) return 0;
65           aStringItr++;
66         }
67         aPatternItr++;
68       }
69 
70       // If we've hit the end of the pattern string, match
71       if (aPatternItr == aPatternEnd) return 1;
72 
73       while (aStringItr != aStringEnd) {
74         if (likeCompare(aPatternItr, aPatternEnd, aStringItr, aStringEnd,
75                         aEscapeChar)) {
76           // we've hit a match, so indicate this
77           return 1;
78         }
79         aStringItr++;
80       }
81 
82       // No match
83       return 0;
84     } else if (!lastWasEscape && *aPatternItr == MATCH_ONE) {
85       // CASE 2
86       if (aStringItr == aStringEnd) {
87         // If we've hit the end of the string we are testing, no match
88         return 0;
89       }
90       aStringItr++;
91       lastWasEscape = false;
92     } else if (!lastWasEscape && *aPatternItr == aEscapeChar) {
93       // CASE 3
94       lastWasEscape = true;
95     } else {
96       // CASE 4
97       if (::ToUpperCase(*aStringItr) != ::ToUpperCase(*aPatternItr)) {
98         // If we've hit a point where the strings don't match, there is no match
99         return 0;
100       }
101       aStringItr++;
102       lastWasEscape = false;
103     }
104 
105     aPatternItr++;
106   }
107 
108   return aStringItr == aStringEnd;
109 }
110 
111 /**
112  * Compute the Levenshtein Edit Distance between two strings.
113  *
114  * @param aStringS
115  *        a string
116  * @param aStringT
117  *        another string
118  * @param _result
119  *        an outparam that will receive the edit distance between the arguments
120  * @return a Sqlite result code, e.g. SQLITE_OK, SQLITE_NOMEM, etc.
121  */
levenshteinDistance(const nsAString & aStringS,const nsAString & aStringT,int * _result)122 int levenshteinDistance(const nsAString &aStringS, const nsAString &aStringT,
123                         int *_result) {
124   // Set the result to a non-sensical value in case we encounter an error.
125   *_result = -1;
126 
127   const uint32_t sLen = aStringS.Length();
128   const uint32_t tLen = aStringT.Length();
129 
130   if (sLen == 0) {
131     *_result = tLen;
132     return SQLITE_OK;
133   }
134   if (tLen == 0) {
135     *_result = sLen;
136     return SQLITE_OK;
137   }
138 
139   // Notionally, Levenshtein Distance is computed in a matrix.  If we
140   // assume s = "span" and t = "spam", the matrix would look like this:
141   //    s -->
142   //  t          s   p   a   n
143   //  |      0   1   2   3   4
144   //  V  s   1   *   *   *   *
145   //     p   2   *   *   *   *
146   //     a   3   *   *   *   *
147   //     m   4   *   *   *   *
148   //
149   // Note that the row width is sLen + 1 and the column height is tLen + 1,
150   // where sLen is the length of the string "s" and tLen is the length of "t".
151   // The first row and the first column are initialized as shown, and
152   // the algorithm computes the remaining cells row-by-row, and
153   // left-to-right within each row.  The computation only requires that
154   // we be able to see the current row and the previous one.
155 
156   // Allocate memory for two rows.
157   AutoTArray<int, nsAutoString::kStorageSize> row1;
158   AutoTArray<int, nsAutoString::kStorageSize> row2;
159 
160   // Declare the raw pointers that will actually be used to access the memory.
161   int *prevRow = row1.AppendElements(sLen + 1);
162   int *currRow = row2.AppendElements(sLen + 1);
163 
164   // Initialize the first row.
165   for (uint32_t i = 0; i <= sLen; i++) prevRow[i] = i;
166 
167   const char16_t *s = aStringS.BeginReading();
168   const char16_t *t = aStringT.BeginReading();
169 
170   // Compute the empty cells in the "matrix" row-by-row, starting with
171   // the second row.
172   for (uint32_t ti = 1; ti <= tLen; ti++) {
173     // Initialize the first cell in this row.
174     currRow[0] = ti;
175 
176     // Get the character from "t" that corresponds to this row.
177     const char16_t tch = t[ti - 1];
178 
179     // Compute the remaining cells in this row, left-to-right,
180     // starting at the second column (and first character of "s").
181     for (uint32_t si = 1; si <= sLen; si++) {
182       // Get the character from "s" that corresponds to this column,
183       // compare it to the t-character, and compute the "cost".
184       const char16_t sch = s[si - 1];
185       int cost = (sch == tch) ? 0 : 1;
186 
187       // ............ We want to calculate the value of cell "d" from
188       // ...ab....... the previously calculated (or initialized) cells
189       // ...cd....... "a", "b", and "c", where d = min(a', b', c').
190       // ............
191       int aPrime = prevRow[si - 1] + cost;
192       int bPrime = prevRow[si] + 1;
193       int cPrime = currRow[si - 1] + 1;
194       currRow[si] = std::min(aPrime, std::min(bPrime, cPrime));
195     }
196 
197     // Advance to the next row.  The current row becomes the previous
198     // row and we recycle the old previous row as the new current row.
199     // We don't need to re-initialize the new current row since we will
200     // rewrite all of its cells anyway.
201     int *oldPrevRow = prevRow;
202     prevRow = currRow;
203     currRow = oldPrevRow;
204   }
205 
206   // The final result is the value of the last cell in the last row.
207   // Note that that's now in the "previous" row, since we just swapped them.
208   *_result = prevRow[sLen];
209   return SQLITE_OK;
210 }
211 
// This struct is used only by registerFunctions below, but ISO C++98 forbids
// instantiating a template dependent on a locally-defined type.  Boo-urns!
//
// Each instance describes one SQL function registration.  registerFunctions
// forwards the members, in declaration order, to ::sqlite3_create_function.
struct Functions {
  // Name under which the function is registered.
  const char *zName;
  // Number of arguments the registration accepts.
  int nArg;
  // Text encoding constant for the registration (SQLITE_UTF8 or SQLITE_UTF16
  // in the table built by registerFunctions).
  int enc;
  // Opaque pointer handed to ::sqlite3_create_function as the user-data
  // argument.
  void *pContext;
  // Callback that implements the function.
  void (*xFunc)(::sqlite3_context *, int, sqlite3_value **);
};
221 
222 }  // namespace
223 
224 ////////////////////////////////////////////////////////////////////////////////
225 //// Exposed Functions
226 
registerFunctions(sqlite3 * aDB)227 int registerFunctions(sqlite3 *aDB) {
228   Functions functions[] = {
229       {"lower", 1, SQLITE_UTF16, 0, caseFunction},
230       {"lower", 1, SQLITE_UTF8, 0, caseFunction},
231       {"upper", 1, SQLITE_UTF16, (void *)1, caseFunction},
232       {"upper", 1, SQLITE_UTF8, (void *)1, caseFunction},
233 
234       {"like", 2, SQLITE_UTF16, 0, likeFunction},
235       {"like", 2, SQLITE_UTF8, 0, likeFunction},
236       {"like", 3, SQLITE_UTF16, 0, likeFunction},
237       {"like", 3, SQLITE_UTF8, 0, likeFunction},
238 
239       {"levenshteinDistance", 2, SQLITE_UTF16, 0, levenshteinDistanceFunction},
240       {"levenshteinDistance", 2, SQLITE_UTF8, 0, levenshteinDistanceFunction},
241   };
242 
243   int rv = SQLITE_OK;
244   for (size_t i = 0; SQLITE_OK == rv && i < ArrayLength(functions); ++i) {
245     struct Functions *p = &functions[i];
246     rv = ::sqlite3_create_function(aDB, p->zName, p->nArg, p->enc, p->pContext,
247                                    p->xFunc, nullptr, nullptr);
248   }
249 
250   return rv;
251 }
252 
253 ////////////////////////////////////////////////////////////////////////////////
254 //// SQL Functions
255 
caseFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)256 void caseFunction(sqlite3_context *aCtx, int aArgc, sqlite3_value **aArgv) {
257   NS_ASSERTION(1 == aArgc, "Invalid number of arguments!");
258 
259   nsAutoString data(
260       static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[0])));
261   bool toUpper = ::sqlite3_user_data(aCtx) ? true : false;
262 
263   if (toUpper)
264     ::ToUpperCase(data);
265   else
266     ::ToLowerCase(data);
267 
268   // Set the result.
269   ::sqlite3_result_text16(aCtx, data.get(), -1, SQLITE_TRANSIENT);
270 }
271 
272 /**
273  * This implements the like() SQL function.  This is used by the LIKE operator.
274  * The SQL statement 'A LIKE B' is implemented as 'like(B, A)', and if there is
275  * an escape character, say E, it is implemented as 'like(B, A, E)'.
276  */
likeFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)277 void likeFunction(sqlite3_context *aCtx, int aArgc, sqlite3_value **aArgv) {
278   NS_ASSERTION(2 == aArgc || 3 == aArgc, "Invalid number of arguments!");
279 
280   if (::sqlite3_value_bytes(aArgv[0]) > SQLITE_MAX_LIKE_PATTERN_LENGTH) {
281     ::sqlite3_result_error(aCtx, "LIKE or GLOB pattern too complex",
282                            SQLITE_TOOBIG);
283     return;
284   }
285 
286   if (!::sqlite3_value_text16(aArgv[0]) || !::sqlite3_value_text16(aArgv[1]))
287     return;
288 
289   nsDependentString A(
290       static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[1])));
291   nsDependentString B(
292       static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[0])));
293   NS_ASSERTION(!B.IsEmpty(), "LIKE string must not be null!");
294 
295   char16_t E = 0;
296   if (3 == aArgc)
297     E = static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[2]))[0];
298 
299   nsAString::const_iterator itrString, endString;
300   A.BeginReading(itrString);
301   A.EndReading(endString);
302   nsAString::const_iterator itrPattern, endPattern;
303   B.BeginReading(itrPattern);
304   B.EndReading(endPattern);
305   ::sqlite3_result_int(
306       aCtx, likeCompare(itrPattern, endPattern, itrString, endString, E));
307 }
308 
levenshteinDistanceFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)309 void levenshteinDistanceFunction(sqlite3_context *aCtx, int aArgc,
310                                  sqlite3_value **aArgv) {
311   NS_ASSERTION(2 == aArgc, "Invalid number of arguments!");
312 
313   // If either argument is a SQL NULL, then return SQL NULL.
314   if (::sqlite3_value_type(aArgv[0]) == SQLITE_NULL ||
315       ::sqlite3_value_type(aArgv[1]) == SQLITE_NULL) {
316     ::sqlite3_result_null(aCtx);
317     return;
318   }
319 
320   int aLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
321   const char16_t *a =
322       static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[0]));
323 
324   int bLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(char16_t);
325   const char16_t *b =
326       static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[1]));
327 
328   // Compute the Levenshtein Distance, and return the result (or error).
329   int distance = -1;
330   const nsDependentString A(a, aLen);
331   const nsDependentString B(b, bLen);
332   int status = levenshteinDistance(A, B, &distance);
333   if (status == SQLITE_OK) {
334     ::sqlite3_result_int(aCtx, distance);
335   } else if (status == SQLITE_NOMEM) {
336     ::sqlite3_result_error_nomem(aCtx);
337   } else {
338     ::sqlite3_result_error(aCtx, "User function returned error code", -1);
339   }
340 }
341 
342 }  // namespace storage
343 }  // namespace mozilla
344