1 /*
2  * sqlite_extension.cpp
3  * Copyright (C) 2021 Kovid Goyal <kovid at kovidgoyal.net>
4  *
5  * Distributed under terms of the GPL3 license.
6  */
7 
8 #define PY_SSIZE_T_CLEAN
9 #define UNICODE
10 #include <Python.h>
11 #include <stdlib.h>
12 #include <string>
13 #include <locale>
14 #include <vector>
15 #include <unordered_map>
16 #include <mutex>
17 #include <cstring>
18 #include <sqlite3ext.h>
19 #include <unicode/unistr.h>
20 #include <unicode/uchar.h>
21 #include <unicode/translit.h>
22 #include <unicode/errorcode.h>
23 #include <unicode/brkiter.h>
24 #include <unicode/uscript.h>
25 #if __has_include(<libstemmer.h>)
26 #include <libstemmer.h>
27 #else
28 #include <libstemmer/libstemmer.h>
29 #endif
30 #include "../utils/cpp_binding.h"
31 SQLITE_EXTENSION_INIT1
32 
33 typedef int (*token_callback_func)(void *, int, const char *, int, int, int);
34 
35 
36 // Converting SQLITE text to ICU strings {{{
37 // UTF-8 decode taken from: https://bjoern.hoehrmann.de/utf-8/decoder/dfa/
38 
// Lookup table for the DFA based UTF-8 decoder. The first 256 entries map a
// byte value to its character class. The entries after that hold the state
// transitions, 16 per state, indexed by 256 + state * 16 + character class.
static const uint8_t utf8_data[] = {
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 00..1f
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 20..3f
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 40..5f
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 60..7f
  1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, // 80..9f
  7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, // a0..bf
  8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, // c0..df
  0xa,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x4,0x3,0x3, // e0..ef
  0xb,0x6,0x6,0x6,0x5,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8, // f0..ff
  0x0,0x1,0x2,0x3,0x5,0x8,0x7,0x1,0x1,0x1,0x4,0x6,0x1,0x1,0x1,0x1, // s0..s0
  1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1, // s1..s2
  1,2,1,1,1,1,1,2,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1, // s3..s4
  1,2,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,3,1,3,1,1,1,1,1,1, // s5..s6
  1,3,1,1,1,1,1,3,1,3,1,1,1,1,1,1,1,3,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // s7..s8
};


// Decoder state. Only the two terminal states are named, but decode_utf8()
// also stores the intermediate states from utf8_data (values 2..8) in
// variables of this type. The underlying type is therefore fixed: without it
// only 0 and 1 would be valid values of the enumeration and converting any
// other state to UTF8State would be undefined behaviour.
typedef enum UTF8State : uint32_t { UTF8_ACCEPT = 0, UTF8_REJECT = 1} UTF8State;

// Feed one byte to the UTF-8 decoder.
//
// state: in/out decoder state, must be UTF8_ACCEPT before the first byte of a
//        sequence. The caller has to reset it after UTF8_REJECT is returned.
// codep: in/out accumulator, holds the decoded code point once UTF8_ACCEPT is
//        returned.
// byte:  the next input byte.
//
// Returns the new state: UTF8_ACCEPT when a complete code point is available,
// UTF8_REJECT for invalid input and some other value when more bytes are
// needed.
uint32_t
decode_utf8(UTF8State* state, uint32_t* codep, uint8_t byte) {
  const uint32_t type = utf8_data[byte];

  if (*state != UTF8_ACCEPT) {
    // continuation byte: append its six payload bits
    *codep = (byte & 0x3fu) | (*codep << 6);
  } else {
    // first byte of a sequence: the character class says how many leading
    // bits are markers and have to be masked off
    *codep = (0xffu >> type) & byte;
  }

  *state = static_cast<UTF8State>(utf8_data[256 + *state * 16 + type]);
  return *state;
}
70 
71 
72 static void
populate_icu_string(const char * text,int text_sz,icu::UnicodeString & str,std::vector<int> & byte_offsets)73 populate_icu_string(const char *text, int text_sz, icu::UnicodeString &str, std::vector<int> &byte_offsets) {
74     UTF8State state = UTF8_ACCEPT, prev = UTF8_ACCEPT;
75     uint32_t codep = 0;
76     for (int i = 0, pos = 0; i < text_sz; i++) {
77         switch(decode_utf8(&state, &codep, text[i])) {
78             case UTF8_ACCEPT: {
79                 size_t sz = str.length();
80                 str.append((UChar32)codep);
81                 sz = str.length() - sz;
82                 for (size_t x = 0; x < sz; x++) byte_offsets.push_back(pos);
83                 pos = i + 1;
84             }
85                 break;
86             case UTF8_REJECT:
87                 state = UTF8_ACCEPT;
88                 if (prev != UTF8_ACCEPT && i > 0) i--;
89                 break;
90         }
91         prev = state;
92     }
93     byte_offsets.push_back(text_sz);
94 }
95 // }}}
96 
// Name of the current UI language, "en" until set_ui_language() is called.
// Every access in this file happens while global_mutex is held.
static char ui_language[16] = "en";
static std::mutex global_mutex;
99 
100 class IteratorDescription {
101     public:
102         const char *language;
103         UScriptCode script;
104 };
105 
106 class Stemmer {
107 private:
108     struct sb_stemmer *handle;
109     char lang_name[32];
110 
111 public:
Stemmer()112     Stemmer() : handle(NULL), lang_name() {}
Stemmer(const char * lang)113     Stemmer(const char *lang) {
114         size_t len = strlen(lang);
115         for (size_t i = 0; i < sizeof(lang_name) - 1 && i < len; i++) {
116             lang_name[i] = lang[i];
117             if ('A' <= lang_name[i] && lang_name[i] <= 'Z') lang_name[i] += 'a' - 'A';
118         }
119         lang_name[std::min(len, sizeof(lang_name) - 1)] = 0;
120         handle = sb_stemmer_new(lang_name, NULL);
121     }
language_name() const122     const char* language_name() const { return lang_name; }
123 
~Stemmer()124     ~Stemmer() {
125         if (handle) {
126             sb_stemmer_delete(handle);
127             handle = NULL;
128         }
129     }
130 
stem(const char * token,size_t token_sz,int * sz)131     const char* stem(const char *token, size_t token_sz, int *sz) {
132         const char *ans = NULL;
133         if (handle) {
134             ans = reinterpret_cast<const char*>(sb_stemmer_stem(handle, reinterpret_cast<const sb_symbol*>(token), (int)token_sz));
135             if (ans) *sz = sb_stemmer_length(handle);
136         }
137         return ans;
138     }
139 
operator bool() const140     explicit operator bool() const noexcept { return handle != NULL; }
141 };
142 
143 typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
144 typedef std::unique_ptr<Stemmer> StemmerPtr;
145 static const std::string empty_string("");
146 
147 class Tokenizer {
148 private:
149     bool remove_diacritics, stem_words;
150     std::unique_ptr<icu::Transliterator> diacritics_remover;
151     std::vector<int> byte_offsets;
152     std::string token_buf, current_ui_language;
153     token_callback_func current_callback;
154     void *current_callback_ctx;
155     std::unordered_map<std::string, BreakIterator> iterators;
156     std::unordered_map<std::string, StemmerPtr> stemmers;
157 
is_token_char(UChar32 ch) const158     bool is_token_char(UChar32 ch) const {
159         switch(u_charType(ch)) {
160             case U_UPPERCASE_LETTER:
161             case U_LOWERCASE_LETTER:
162             case U_TITLECASE_LETTER:
163             case U_MODIFIER_LETTER:
164             case U_OTHER_LETTER:
165             case U_DECIMAL_DIGIT_NUMBER:
166             case U_LETTER_NUMBER:
167             case U_OTHER_NUMBER:
168             case U_CURRENCY_SYMBOL:
169             case U_OTHER_SYMBOL:
170             case U_PRIVATE_USE_CHAR:
171                 return true;
172             default:
173                 break;;
174         }
175         return false;
176     }
177 
send_token(const icu::UnicodeString & token,int32_t start_offset,int32_t end_offset,StemmerPtr & stemmer,int flags=0)178     int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, StemmerPtr &stemmer, int flags = 0) {
179         token_buf.clear(); token_buf.reserve(4 * token.length());
180         token.toUTF8String(token_buf);
181         const char *root = token_buf.c_str(); int sz = (int)token_buf.size();
182         if (stem_words && stemmer->operator bool()) {
183             root = stemmer->stem(root, sz, &sz);
184             if (!root) {
185                 root = token_buf.c_str();
186                 sz = (int)token_buf.size();
187             }
188         }
189         return current_callback(current_callback_ctx, flags, root, (int)sz, byte_offsets.at(start_offset), byte_offsets.at(end_offset));
190     }
191 
iterator_language_for_script(UScriptCode script) const192     const char* iterator_language_for_script(UScriptCode script) const {
193         switch (script) {
194             default:
195                 return "";
196             case USCRIPT_THAI:
197             case USCRIPT_LAO:
198                 return "th_TH";
199             case USCRIPT_KHMER:
200                 return "km_KH";
201             case USCRIPT_MYANMAR:
202                 return "my_MM";
203             case USCRIPT_HIRAGANA:
204             case USCRIPT_KATAKANA:
205                 return "ja_JP";
206             case USCRIPT_HANGUL:
207                 return "ko_KR";
208             case USCRIPT_HAN:
209             case USCRIPT_SIMPLIFIED_HAN:
210             case USCRIPT_TRADITIONAL_HAN:
211             case USCRIPT_HAN_WITH_BOPOMOFO:
212                 return "zh";
213         }
214     }
215 
at_script_boundary(IteratorDescription & current,UChar32 next_codepoint) const216     bool at_script_boundary(IteratorDescription &current, UChar32 next_codepoint) const {
217         icu::ErrorCode err;
218         UScriptCode script = uscript_getScript(next_codepoint, err);
219         if (script == USCRIPT_COMMON || script == USCRIPT_INVALID_CODE || script == USCRIPT_INHERITED || current.script == script) return false;
220         const char *lang = iterator_language_for_script(script);
221         if (strcmp(current.language, lang) == 0) return false;
222         current.script = script; current.language = lang;
223         return true;
224     }
225 
ensure_basic_iterator(void)226     void ensure_basic_iterator(void) {
227         std::lock_guard<std::mutex> lock(global_mutex);
228         if (current_ui_language != ui_language || iterators.find(empty_string) == iterators.end()) {
229             current_ui_language.clear(); current_ui_language = ui_language;
230             icu::ErrorCode status;
231             if (current_ui_language.empty()) {
232                 iterators[empty_string] = BreakIterator(icu::BreakIterator::createWordInstance(icu::Locale::getDefault(), status));
233             } else {
234                 ensure_lang_iterator(ui_language);
235             }
236         }
237     }
238 
ensure_lang_iterator(const char * lang="")239     BreakIterator& ensure_lang_iterator(const char *lang = "") {
240         auto ans = iterators.find(lang);
241         if (ans == iterators.end()) {
242             icu::ErrorCode status;
243             iterators[lang] = BreakIterator(icu::BreakIterator::createWordInstance(icu::Locale::createCanonical(lang), status));
244             if (status.isFailure()) {
245                 iterators[lang] = BreakIterator(icu::BreakIterator::createWordInstance(icu::Locale::getDefault(), status));
246             }
247             ans = iterators.find(lang);
248         }
249         return ans->second;
250     }
251 
ensure_stemmer(const char * lang="")252     StemmerPtr& ensure_stemmer(const char *lang = "") {
253         if (!lang[0]) lang = current_ui_language.c_str();
254         auto ans = stemmers.find(lang);
255         if (ans == stemmers.end()) {
256             stemmers[lang] = stem_words ? std::make_unique<Stemmer>(lang) : std::make_unique<Stemmer>();
257             ans = stemmers.find(lang);
258         }
259         return ans->second;
260     }
261 
tokenize_script_block(const icu::UnicodeString & str,int32_t block_start,int32_t block_limit,bool for_query,token_callback_func callback,void * callback_ctx,BreakIterator & word_iterator,StemmerPtr & stemmer)262     int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, StemmerPtr &stemmer) {
263         word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
264         int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
265         int rc = SQLITE_OK;
266         do {
267             token_end_pos = word_iterator->next();
268             if (token_end_pos == icu::BreakIterator::DONE) token_end_pos = block_limit;
269             else token_end_pos += block_start;
270             if (token_end_pos > token_start_pos) {
271                 bool is_token = false;
272                 for (int32_t pos = token_start_pos; !is_token && pos < token_end_pos; pos = str.moveIndex32(pos, 1)) {
273                     if (is_token_char(str.char32At(pos))) is_token = true;
274                 }
275                 if (is_token) {
276                     icu::UnicodeString token(str, token_start_pos, token_end_pos - token_start_pos);
277                     token.foldCase();
278                     if ((rc = send_token(token, token_start_pos, token_end_pos, stemmer)) != SQLITE_OK) return rc;
279                     if (!for_query && remove_diacritics) {
280                         icu::UnicodeString tt(str, token_start_pos, token_end_pos - token_start_pos);
281                         diacritics_remover->transliterate(tt);
282                         tt.foldCase();
283                         if (tt != token) {
284                             if ((rc = send_token(tt, token_start_pos, token_end_pos, stemmer, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc;
285                         }
286                     }
287                 }
288             }
289             token_start_pos = token_end_pos;
290         } while (token_end_pos < block_limit);
291         return rc;
292     }
293 
294 public:
295     int constructor_error;
296 
Tokenizer(const char ** args,int nargs,bool stem_words=false)297     Tokenizer(const char **args, int nargs, bool stem_words = false) :
298         remove_diacritics(true), stem_words(stem_words), diacritics_remover(),
299         byte_offsets(), token_buf(), current_ui_language(""),
300         current_callback(NULL), current_callback_ctx(NULL),
301         iterators(), stemmers(),
302 
303         constructor_error(SQLITE_OK)
304     {
305         for (int i = 0; i < nargs; i++) {
306             if (strcmp(args[i], "remove_diacritics") == 0) {
307                 i++;
308                 if (i < nargs && strcmp(args[i], "0") == 0) remove_diacritics = false;
309             }
310             else if (strcmp(args[i], "stem_words") == 0) {
311                 i++;
312                 if (i < nargs && strcmp(args[i], "0") == 0) stem_words = false;
313                 else stem_words = true;
314             }
315         }
316         if (remove_diacritics) {
317             icu::ErrorCode status;
318             diacritics_remover.reset(icu::Transliterator::createInstance("NFD; [:M:] Remove; NFC", UTRANS_FORWARD, status));
319             if (status.isFailure()) {
320                 fprintf(stderr, "Failed to create ICU transliterator to remove diacritics with error: %s\n", status.errorName());
321                 constructor_error = SQLITE_INTERNAL;
322                 diacritics_remover.reset(NULL);
323                 remove_diacritics = false;
324             }
325         }
326         std::lock_guard<std::mutex> lock(global_mutex);
327         current_ui_language = ui_language;
328     }
329 
tokenize(void * callback_ctx,int flags,const char * text,int text_sz,token_callback_func callback)330     int tokenize(void *callback_ctx, int flags, const char *text, int text_sz, token_callback_func callback) {
331         ensure_basic_iterator();
332         current_callback = callback; current_callback_ctx = callback_ctx;
333         icu::UnicodeString str(text_sz, 0, 0);
334         byte_offsets.clear();
335         byte_offsets.reserve(text_sz + 8);
336         populate_icu_string(text, text_sz, str, byte_offsets);
337         int32_t offset = str.getChar32Start(0);
338         int rc = SQLITE_OK;
339         bool for_query = (flags & FTS5_TOKENIZE_QUERY) != 0;
340         IteratorDescription state;
341         state.language = ""; state.script = USCRIPT_COMMON;
342         int32_t start_script_block_at = offset;
343         auto word_iterator = std::ref(ensure_lang_iterator(state.language));
344         auto stemmer = std::ref(ensure_stemmer(state.language));
345         while (offset < str.length()) {
346             UChar32 ch = str.char32At(offset);
347             if (at_script_boundary(state, ch)) {
348                 if (offset > start_script_block_at) {
349                     if ((rc = tokenize_script_block(
350                         str, start_script_block_at, offset,
351                         for_query, callback, callback_ctx, word_iterator, stemmer)) != SQLITE_OK) return rc;
352                 }
353                 start_script_block_at = offset;
354                 word_iterator = ensure_lang_iterator(state.language);
355                 stemmer = ensure_stemmer(state.language);
356             }
357             offset = str.moveIndex32(offset, 1);
358         }
359         if (offset > start_script_block_at) {
360             rc = tokenize_script_block(str, start_script_block_at, offset, for_query, callback, callback_ctx, word_iterator, stemmer);
361         }
362         return rc;
363     }
364 };
365 
366 // boilerplate {{{
367 static int
fts5_api_from_db(sqlite3 * db,fts5_api ** ppApi)368 fts5_api_from_db(sqlite3 *db, fts5_api **ppApi) {
369     sqlite3_stmt *pStmt = 0;
370     *ppApi = 0;
371     int rc = sqlite3_prepare(db, "SELECT fts5(?1)", -1, &pStmt, 0);
372     if (rc == SQLITE_OK) {
373         sqlite3_bind_pointer(pStmt, 1, reinterpret_cast<void *>(ppApi), "fts5_api_ptr", 0);
374         (void)sqlite3_step(pStmt);
375         rc = sqlite3_finalize(pStmt);
376     }
377     return rc;
378 }
379 
380 static int
_tok_create(void * sqlite3,const char ** azArg,int nArg,Fts5Tokenizer ** ppOut,bool stem_words=false)381 _tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut, bool stem_words = false) {
382     int rc = SQLITE_OK;
383     try {
384         Tokenizer *p = new Tokenizer(azArg, nArg, stem_words);
385         if (p->constructor_error != SQLITE_OK)  {
386             rc = p->constructor_error;
387             delete p;
388         } else {
389             *ppOut = reinterpret_cast<Fts5Tokenizer *>(p);
390         }
391     } catch (std::bad_alloc const&) {
392         return SQLITE_NOMEM;
393     } catch (...) {
394         return SQLITE_ERROR;
395     }
396     return rc;
397 }
398 
399 static int
tok_create(void * sqlite3,const char ** azArg,int nArg,Fts5Tokenizer ** ppOut)400 tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return _tok_create(sqlite3, azArg, nArg, ppOut); }
401 
402 static int
tok_create_with_stemming(void * sqlite3,const char ** azArg,int nArg,Fts5Tokenizer ** ppOut)403 tok_create_with_stemming(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return _tok_create(sqlite3, azArg, nArg, ppOut, true); }
404 
405 static int
tok_tokenize(Fts5Tokenizer * tokenizer_ptr,void * callback_ctx,int flags,const char * text,int text_sz,token_callback_func callback)406 tok_tokenize(Fts5Tokenizer *tokenizer_ptr, void *callback_ctx, int flags, const char *text, int text_sz, token_callback_func callback) {
407     Tokenizer *p = reinterpret_cast<Tokenizer*>(tokenizer_ptr);
408     try {
409         return p->tokenize(callback_ctx, flags, text, text_sz, callback);
410     } catch (std::bad_alloc const&) {
411         return SQLITE_NOMEM;
412     } catch (...) {
413         return SQLITE_ERROR;
414     }
415 
416 }
417 
418 static void
tok_delete(Fts5Tokenizer * p)419 tok_delete(Fts5Tokenizer *p) {
420     Tokenizer *t = reinterpret_cast<Tokenizer*>(p);
421     delete t;
422 }
423 
424 extern "C" {
425 #ifdef _MSC_VER
426 #define MYEXPORT __declspec(dllexport)
427 #else
428 #define MYEXPORT __attribute__ ((visibility ("default")))
429 #endif
430 
431 MYEXPORT int
calibre_sqlite_extension_init(sqlite3 * db,char ** pzErrMsg,const sqlite3_api_routines * pApi)432 calibre_sqlite_extension_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi){
433     SQLITE_EXTENSION_INIT2(pApi);
434     fts5_api *fts5api = NULL;
435     int rc = fts5_api_from_db(db, &fts5api);
436     if (rc != SQLITE_OK) {
437         *pzErrMsg = (char*)"Failed to get FTS 5 API with error code";
438         return rc;
439     }
440     if (!fts5api || fts5api->iVersion < 2) {
441         *pzErrMsg = (char*)"FTS 5 iVersion too old or NULL pointer";
442         return SQLITE_ERROR;
443     }
444     fts5_tokenizer tok = {tok_create, tok_delete, tok_tokenize};
445     fts5api->xCreateTokenizer(fts5api, "unicode61", reinterpret_cast<void *>(fts5api), &tok, NULL);
446     fts5api->xCreateTokenizer(fts5api, "calibre", reinterpret_cast<void *>(fts5api), &tok, NULL);
447     fts5_tokenizer tok2 = {tok_create_with_stemming, tok_delete, tok_tokenize};
448     fts5api->xCreateTokenizer(fts5api, "porter", reinterpret_cast<void *>(fts5api), &tok2, NULL);
449     return SQLITE_OK;
450 }
451 }
452 
453 static PyObject*
get_locales_for_break_iteration(PyObject * self,PyObject * args)454 get_locales_for_break_iteration(PyObject *self, PyObject *args) {
455     std::unique_ptr<icu::StringEnumeration> locs(icu::BreakIterator::getAvailableLocales());
456     icu::ErrorCode status;
457     pyobject_raii ans(PyList_New(0));
458     if (ans) {
459         const icu::UnicodeString *item;
460         while ((item = locs->snext(status))) {
461             std::string name;
462             item->toUTF8String(name);
463             pyobject_raii pn(PyUnicode_FromString(name.c_str()));
464             if (pn) PyList_Append(ans.ptr(), pn.ptr());
465         }
466         if (status.isFailure()) {
467             PyErr_Format(PyExc_RuntimeError, "Failed to iterate over locales with error: %s", status.errorName());
468             return NULL;
469         }
470     }
471     return ans.detach();
472 }
473 
474 static PyObject*
set_ui_language(PyObject * self,PyObject * args)475 set_ui_language(PyObject *self, PyObject *args) {
476     std::lock_guard<std::mutex> lock(global_mutex);
477     const char *val;
478     if (!PyArg_ParseTuple(args, "s", &val)) return NULL;
479     strncpy(ui_language, val, sizeof(ui_language) - 1);
480     Py_RETURN_NONE;
481 }
482 
483 static int
py_callback(void * ctx,int flags,const char * text,int text_length,int start_offset,int end_offset)484 py_callback(void *ctx, int flags, const char *text, int text_length, int start_offset, int end_offset) {
485     PyObject *ans = reinterpret_cast<PyObject*>(ctx);
486     Py_ssize_t psz = text_length;
487     pyobject_raii item(Py_BuildValue("{ss# si si si}", "text", text, psz, "start", start_offset, "end", end_offset, "flags", flags));
488     if (item) PyList_Append(ans, item.ptr());
489     return SQLITE_OK;
490 }
491 
492 static PyObject*
tokenize(PyObject * self,PyObject * args)493 tokenize(PyObject *self, PyObject *args) {
494     const char *text; Py_ssize_t text_length, remove_diacritics = 1, flags = FTS5_TOKENIZE_DOCUMENT;
495     if (!PyArg_ParseTuple(args, "s#|pi", &text, &text_length, &remove_diacritics, &flags)) return NULL;
496     const char *targs[2] = {"remove_diacritics", "2"};
497     if (!remove_diacritics) targs[1] = "0";
498     Tokenizer t(targs, sizeof(targs)/sizeof(targs[0]));
499     pyobject_raii ans(PyList_New(0));
500     if (!ans) return NULL;
501     t.tokenize(ans.ptr(), flags, text, text_length, py_callback);
502     return ans.detach();
503 }
504 
505 static PyObject*
stem(PyObject * self,PyObject * args)506 stem(PyObject *self, PyObject *args) {
507     const char *text, *lang = "en"; Py_ssize_t text_length;
508     if (!PyArg_ParseTuple(args, "s#|s", &text, &text_length, &lang)) return NULL;
509     Stemmer s(lang);
510     if (!s) {
511         PyErr_SetString(PyExc_ValueError, "No stemmer for the specified language");
512         return NULL;
513     }
514     int sz;
515     const char* result = s.stem(text, text_length, &sz);
516     Py_ssize_t a = sz;
517     if (!result) return PyErr_NoMemory();
518     return Py_BuildValue("s#", result, a);
519 }
520 
521 static PyMethodDef methods[] = {
522     {"get_locales_for_break_iteration", get_locales_for_break_iteration, METH_NOARGS,
523      "Get list of available locales for break iteration"
524     },
525     {"set_ui_language", set_ui_language, METH_VARARGS,
526      "Set the current UI language"
527     },
528     {"tokenize", tokenize, METH_VARARGS,
529      "Tokenize a string, useful for testing"
530     },
531     {"stem", stem, METH_VARARGS,
532      "Stem a word in the specified language, defaulting to English"
533     },
534     {NULL, NULL, 0, NULL}
535 };
536 
537 static int
exec_module(PyObject * mod)538 exec_module(PyObject *mod) {
539     if (PyModule_AddIntMacro(mod, FTS5_TOKENIZE_QUERY) != 0) return 1;
540     if (PyModule_AddIntMacro(mod, FTS5_TOKENIZE_DOCUMENT) != 0) return 1;
541     if (PyModule_AddIntMacro(mod, FTS5_TOKENIZE_PREFIX) != 0) return 1;
542     if (PyModule_AddIntMacro(mod, FTS5_TOKENIZE_AUX) != 0) return 1;
543     if (PyModule_AddIntMacro(mod, FTS5_TOKEN_COLOCATED) != 0) return 1;
544     return 0;
545 }
546 
547 static PyModuleDef_Slot slots[] = { {Py_mod_exec, (void*)exec_module}, {0, NULL} };
548 
549 static struct PyModuleDef module_def = {PyModuleDef_HEAD_INIT};
550 
551 extern "C" {
PyInit_sqlite_extension(void)552 CALIBRE_MODINIT_FUNC PyInit_sqlite_extension(void) {
553     module_def.m_name     = "sqlite_extension";
554     module_def.m_doc      = "Implement ICU based tokenizer for FTS5";
555     module_def.m_methods  = methods;
556     module_def.m_slots    = slots;
557     return PyModuleDef_Init(&module_def);
558 }
559 } // }}}
560