1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2 * vim: sw=2 ts=2 et lcs=trail\:.,tab\:>~ :
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7 #include "mozilla/ArrayUtils.h"
8
9 #include "mozStorageSQLFunctions.h"
10 #include "nsUnicharUtils.h"
11 #include <algorithm>
12
13 namespace mozilla {
14 namespace storage {
15
16 ////////////////////////////////////////////////////////////////////////////////
17 //// Local Helper Functions
18
19 namespace {
20
21 /**
22 * Performs the LIKE comparison of a string against a pattern. For more detail
23 * see http://www.sqlite.org/lang_expr.html#like.
24 *
25 * @param aPatternItr
26 * An iterator at the start of the pattern to check for.
27 * @param aPatternEnd
28 * An iterator at the end of the pattern to check for.
29 * @param aStringItr
30 * An iterator at the start of the string to check for the pattern.
31 * @param aStringEnd
32 * An iterator at the end of the string to check for the pattern.
33 * @param aEscapeChar
34 * The character to use for escaping symbols in the pattern.
35 * @return 1 if the pattern is found, 0 otherwise.
36 */
likeCompare(nsAString::const_iterator aPatternItr,nsAString::const_iterator aPatternEnd,nsAString::const_iterator aStringItr,nsAString::const_iterator aStringEnd,char16_t aEscapeChar)37 int likeCompare(nsAString::const_iterator aPatternItr,
38 nsAString::const_iterator aPatternEnd,
39 nsAString::const_iterator aStringItr,
40 nsAString::const_iterator aStringEnd, char16_t aEscapeChar) {
41 const char16_t MATCH_ALL('%');
42 const char16_t MATCH_ONE('_');
43
44 bool lastWasEscape = false;
45 while (aPatternItr != aPatternEnd) {
46 /**
47 * What we do in here is take a look at each character from the input
48 * pattern, and do something with it. There are 4 possibilities:
49 * 1) character is an un-escaped match-all character
50 * 2) character is an un-escaped match-one character
51 * 3) character is an un-escaped escape character
52 * 4) character is not any of the above
53 */
54 if (!lastWasEscape && *aPatternItr == MATCH_ALL) {
55 // CASE 1
56 /**
57 * Now we need to skip any MATCH_ALL or MATCH_ONE characters that follow a
58 * MATCH_ALL character. For each MATCH_ONE character, skip one character
59 * in the pattern string.
60 */
61 while (*aPatternItr == MATCH_ALL || *aPatternItr == MATCH_ONE) {
62 if (*aPatternItr == MATCH_ONE) {
63 // If we've hit the end of the string we are testing, no match
64 if (aStringItr == aStringEnd) return 0;
65 aStringItr++;
66 }
67 aPatternItr++;
68 }
69
70 // If we've hit the end of the pattern string, match
71 if (aPatternItr == aPatternEnd) return 1;
72
73 while (aStringItr != aStringEnd) {
74 if (likeCompare(aPatternItr, aPatternEnd, aStringItr, aStringEnd,
75 aEscapeChar)) {
76 // we've hit a match, so indicate this
77 return 1;
78 }
79 aStringItr++;
80 }
81
82 // No match
83 return 0;
84 } else if (!lastWasEscape && *aPatternItr == MATCH_ONE) {
85 // CASE 2
86 if (aStringItr == aStringEnd) {
87 // If we've hit the end of the string we are testing, no match
88 return 0;
89 }
90 aStringItr++;
91 lastWasEscape = false;
92 } else if (!lastWasEscape && *aPatternItr == aEscapeChar) {
93 // CASE 3
94 lastWasEscape = true;
95 } else {
96 // CASE 4
97 if (::ToUpperCase(*aStringItr) != ::ToUpperCase(*aPatternItr)) {
98 // If we've hit a point where the strings don't match, there is no match
99 return 0;
100 }
101 aStringItr++;
102 lastWasEscape = false;
103 }
104
105 aPatternItr++;
106 }
107
108 return aStringItr == aStringEnd;
109 }
110
111 /**
112 * Compute the Levenshtein Edit Distance between two strings.
113 *
114 * @param aStringS
115 * a string
116 * @param aStringT
117 * another string
118 * @param _result
119 * an outparam that will receive the edit distance between the arguments
120 * @return a Sqlite result code, e.g. SQLITE_OK, SQLITE_NOMEM, etc.
121 */
levenshteinDistance(const nsAString & aStringS,const nsAString & aStringT,int * _result)122 int levenshteinDistance(const nsAString &aStringS, const nsAString &aStringT,
123 int *_result) {
124 // Set the result to a non-sensical value in case we encounter an error.
125 *_result = -1;
126
127 const uint32_t sLen = aStringS.Length();
128 const uint32_t tLen = aStringT.Length();
129
130 if (sLen == 0) {
131 *_result = tLen;
132 return SQLITE_OK;
133 }
134 if (tLen == 0) {
135 *_result = sLen;
136 return SQLITE_OK;
137 }
138
139 // Notionally, Levenshtein Distance is computed in a matrix. If we
140 // assume s = "span" and t = "spam", the matrix would look like this:
141 // s -->
142 // t s p a n
143 // | 0 1 2 3 4
144 // V s 1 * * * *
145 // p 2 * * * *
146 // a 3 * * * *
147 // m 4 * * * *
148 //
149 // Note that the row width is sLen + 1 and the column height is tLen + 1,
150 // where sLen is the length of the string "s" and tLen is the length of "t".
151 // The first row and the first column are initialized as shown, and
152 // the algorithm computes the remaining cells row-by-row, and
153 // left-to-right within each row. The computation only requires that
154 // we be able to see the current row and the previous one.
155
156 // Allocate memory for two rows.
157 AutoTArray<int, nsAutoString::kStorageSize> row1;
158 AutoTArray<int, nsAutoString::kStorageSize> row2;
159
160 // Declare the raw pointers that will actually be used to access the memory.
161 int *prevRow = row1.AppendElements(sLen + 1);
162 int *currRow = row2.AppendElements(sLen + 1);
163
164 // Initialize the first row.
165 for (uint32_t i = 0; i <= sLen; i++) prevRow[i] = i;
166
167 const char16_t *s = aStringS.BeginReading();
168 const char16_t *t = aStringT.BeginReading();
169
170 // Compute the empty cells in the "matrix" row-by-row, starting with
171 // the second row.
172 for (uint32_t ti = 1; ti <= tLen; ti++) {
173 // Initialize the first cell in this row.
174 currRow[0] = ti;
175
176 // Get the character from "t" that corresponds to this row.
177 const char16_t tch = t[ti - 1];
178
179 // Compute the remaining cells in this row, left-to-right,
180 // starting at the second column (and first character of "s").
181 for (uint32_t si = 1; si <= sLen; si++) {
182 // Get the character from "s" that corresponds to this column,
183 // compare it to the t-character, and compute the "cost".
184 const char16_t sch = s[si - 1];
185 int cost = (sch == tch) ? 0 : 1;
186
187 // ............ We want to calculate the value of cell "d" from
188 // ...ab....... the previously calculated (or initialized) cells
189 // ...cd....... "a", "b", and "c", where d = min(a', b', c').
190 // ............
191 int aPrime = prevRow[si - 1] + cost;
192 int bPrime = prevRow[si] + 1;
193 int cPrime = currRow[si - 1] + 1;
194 currRow[si] = std::min(aPrime, std::min(bPrime, cPrime));
195 }
196
197 // Advance to the next row. The current row becomes the previous
198 // row and we recycle the old previous row as the new current row.
199 // We don't need to re-initialize the new current row since we will
200 // rewrite all of its cells anyway.
201 int *oldPrevRow = prevRow;
202 prevRow = currRow;
203 currRow = oldPrevRow;
204 }
205
206 // The final result is the value of the last cell in the last row.
207 // Note that that's now in the "previous" row, since we just swapped them.
208 *_result = prevRow[sLen];
209 return SQLITE_OK;
210 }
211
212 // This struct is used only by registerFunctions below, but ISO C++98 forbids
213 // instantiating a template dependent on a locally-defined type. Boo-urns!
214 struct Functions {
215 const char *zName;
216 int nArg;
217 int enc;
218 void *pContext;
219 void (*xFunc)(::sqlite3_context *, int, sqlite3_value **);
220 };
221
222 } // namespace
223
224 ////////////////////////////////////////////////////////////////////////////////
225 //// Exposed Functions
226
registerFunctions(sqlite3 * aDB)227 int registerFunctions(sqlite3 *aDB) {
228 Functions functions[] = {
229 {"lower", 1, SQLITE_UTF16, 0, caseFunction},
230 {"lower", 1, SQLITE_UTF8, 0, caseFunction},
231 {"upper", 1, SQLITE_UTF16, (void *)1, caseFunction},
232 {"upper", 1, SQLITE_UTF8, (void *)1, caseFunction},
233
234 {"like", 2, SQLITE_UTF16, 0, likeFunction},
235 {"like", 2, SQLITE_UTF8, 0, likeFunction},
236 {"like", 3, SQLITE_UTF16, 0, likeFunction},
237 {"like", 3, SQLITE_UTF8, 0, likeFunction},
238
239 {"levenshteinDistance", 2, SQLITE_UTF16, 0, levenshteinDistanceFunction},
240 {"levenshteinDistance", 2, SQLITE_UTF8, 0, levenshteinDistanceFunction},
241 };
242
243 int rv = SQLITE_OK;
244 for (size_t i = 0; SQLITE_OK == rv && i < ArrayLength(functions); ++i) {
245 struct Functions *p = &functions[i];
246 rv = ::sqlite3_create_function(aDB, p->zName, p->nArg, p->enc, p->pContext,
247 p->xFunc, nullptr, nullptr);
248 }
249
250 return rv;
251 }
252
253 ////////////////////////////////////////////////////////////////////////////////
254 //// SQL Functions
255
caseFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)256 void caseFunction(sqlite3_context *aCtx, int aArgc, sqlite3_value **aArgv) {
257 NS_ASSERTION(1 == aArgc, "Invalid number of arguments!");
258
259 nsAutoString data(
260 static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[0])));
261 bool toUpper = ::sqlite3_user_data(aCtx) ? true : false;
262
263 if (toUpper)
264 ::ToUpperCase(data);
265 else
266 ::ToLowerCase(data);
267
268 // Set the result.
269 ::sqlite3_result_text16(aCtx, data.get(), -1, SQLITE_TRANSIENT);
270 }
271
272 /**
273 * This implements the like() SQL function. This is used by the LIKE operator.
274 * The SQL statement 'A LIKE B' is implemented as 'like(B, A)', and if there is
275 * an escape character, say E, it is implemented as 'like(B, A, E)'.
276 */
likeFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)277 void likeFunction(sqlite3_context *aCtx, int aArgc, sqlite3_value **aArgv) {
278 NS_ASSERTION(2 == aArgc || 3 == aArgc, "Invalid number of arguments!");
279
280 if (::sqlite3_value_bytes(aArgv[0]) > SQLITE_MAX_LIKE_PATTERN_LENGTH) {
281 ::sqlite3_result_error(aCtx, "LIKE or GLOB pattern too complex",
282 SQLITE_TOOBIG);
283 return;
284 }
285
286 if (!::sqlite3_value_text16(aArgv[0]) || !::sqlite3_value_text16(aArgv[1]))
287 return;
288
289 nsDependentString A(
290 static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[1])));
291 nsDependentString B(
292 static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[0])));
293 NS_ASSERTION(!B.IsEmpty(), "LIKE string must not be null!");
294
295 char16_t E = 0;
296 if (3 == aArgc)
297 E = static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[2]))[0];
298
299 nsAString::const_iterator itrString, endString;
300 A.BeginReading(itrString);
301 A.EndReading(endString);
302 nsAString::const_iterator itrPattern, endPattern;
303 B.BeginReading(itrPattern);
304 B.EndReading(endPattern);
305 ::sqlite3_result_int(
306 aCtx, likeCompare(itrPattern, endPattern, itrString, endString, E));
307 }
308
levenshteinDistanceFunction(sqlite3_context * aCtx,int aArgc,sqlite3_value ** aArgv)309 void levenshteinDistanceFunction(sqlite3_context *aCtx, int aArgc,
310 sqlite3_value **aArgv) {
311 NS_ASSERTION(2 == aArgc, "Invalid number of arguments!");
312
313 // If either argument is a SQL NULL, then return SQL NULL.
314 if (::sqlite3_value_type(aArgv[0]) == SQLITE_NULL ||
315 ::sqlite3_value_type(aArgv[1]) == SQLITE_NULL) {
316 ::sqlite3_result_null(aCtx);
317 return;
318 }
319
320 int aLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
321 const char16_t *a =
322 static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[0]));
323
324 int bLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(char16_t);
325 const char16_t *b =
326 static_cast<const char16_t *>(::sqlite3_value_text16(aArgv[1]));
327
328 // Compute the Levenshtein Distance, and return the result (or error).
329 int distance = -1;
330 const nsDependentString A(a, aLen);
331 const nsDependentString B(b, bLen);
332 int status = levenshteinDistance(A, B, &distance);
333 if (status == SQLITE_OK) {
334 ::sqlite3_result_int(aCtx, distance);
335 } else if (status == SQLITE_NOMEM) {
336 ::sqlite3_result_error_nomem(aCtx);
337 } else {
338 ::sqlite3_result_error(aCtx, "User function returned error code", -1);
339 }
340 }
341
342 } // namespace storage
343 } // namespace mozilla
344