1# -*- coding: utf-8 -*-
2# Copyright: Ankitects Pty Ltd and contributors
3# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
4
5import re
6import sre_constants
7import unicodedata
8
9from anki.utils import ids2str, splitFields, joinFields, intTime, fieldChecksum, stripHTMLMedia
10from anki.consts import *
11from anki.hooks import *
12
13
14# Find
15##########################################################################
16
class Finder:
    """Translate search strings into SQL and run them against a collection.

    A query is split into tokens by ``_tokenize``. ``_where`` then turns the
    tokens into a SQL WHERE clause plus a list of bound parameters,
    dispatching ``cmd:value`` tokens to the handlers registered in
    ``self.search``; unknown commands are treated as field searches. A
    handler returns a SQL fragment, ``"skip"`` to contribute nothing, or a
    falsy value to mark the term as invalid.
    """

    # Collection 'sortType' setting -> SQL ordering used by _order().
    _SORT_COLUMNS = {
        "noteCrt": "n.id, c.ord",
        "noteMod": "n.mod, c.ord",
        "noteFld": "n.sfld collate nocase, c.ord",
        "cardMod": "c.mod",
        "cardReps": "c.reps",
        "cardDue": "c.type, c.due",
        "cardEase": "c.type == 0, c.factor",
        "cardLapses": "c.lapses",
        "cardIvl": "c.ivl",
    }

    def __init__(self, col):
        self.col = col
        # search keyword -> handler; each handler receives one (value, args)
        # tuple and may append bound SQL parameters to args
        self.search = dict(
            added=self._findAdded,
            card=self._findTemplate,
            deck=self._findDeck,
            mid=self._findMid,
            nid=self._findNids,
            cid=self._findCids,
            note=self._findModel,
            prop=self._findProp,
            rated=self._findRated,
            tag=self._findTag,
            dupe=self._findDupes,
            flag=self._findFlag,
        )
        # "is" is a Python keyword, so it can't be a dict() keyword argument
        self.search['is'] = self._findCardState
        # hooks receive the dict and may add or replace handlers
        runHook("search", self.search)

    def findCards(self, query, order=False):
        """Return a list of card ids for QUERY.

        ORDER is False (unordered), True (use the collection's configured
        sort) or a custom SQL ordering expression.

        Raises Exception("invalidSearch") when a search term is invalid.
        Returns [] when the generated SQL fails to run (e.g. unbalanced
        parentheses).
        """
        tokens = self._tokenize(query)
        preds, args = self._where(tokens)
        if preds is None:
            raise Exception("invalidSearch")
        order, rev = self._order(order)
        sql = self._query(preds, order)
        try:
            res = self.col.db.list(sql, *args)
        except Exception:
            # invalid grouping
            return []
        if rev:
            res.reverse()
        return res

    def findNotes(self, query):
        """Return a list of note ids for QUERY.

        Unlike findCards, an invalid search yields [] instead of raising.
        """
        tokens = self._tokenize(query)
        preds, args = self._where(tokens)
        if preds is None:
            return []
        if preds:
            preds = "(" + preds + ")"
        else:
            preds = "1"
        sql = """
select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds
        try:
            res = self.col.db.list(sql, *args)
        except Exception:
            # invalid grouping
            return []
        return res

    # Tokenizing
    ######################################################################

    def _tokenize(self, query):
        """Split QUERY into search tokens.

        Whitespace (including the ideographic space U+3000) separates tokens
        unless quoted; a quote may open a token or directly follow a ':'.
        Unquoted parentheses become standalone tokens. An unquoted '-' at
        the start of a token becomes a standalone negation token, and
        consecutive ones are collapsed.
        """
        inQuote = False
        tokens = []
        token = ""
        for c in query:
            # quoted text
            if c in ("'", '"'):
                if inQuote:
                    if c == inQuote:
                        inQuote = False
                    else:
                        token += c
                elif token:
                    # quotes are allowed to start directly after a :
                    if token[-1] == ":":
                        inQuote = c
                    else:
                        token += c
                else:
                    inQuote = c
            # separator (space and ideographic space)
            elif c in (" ", '\u3000'):
                if inQuote:
                    token += c
                elif token:
                    # space marks token finished
                    tokens.append(token)
                    token = ""
            # nesting
            elif c in ("(", ")"):
                if inQuote:
                    token += c
                else:
                    # flush pending text first so "a(b)" keeps its order
                    # instead of emitting "(" ahead of "a"
                    if token:
                        tokens.append(token)
                        token = ""
                    tokens.append(c)
            # negation
            elif c == "-":
                # inside quotes or mid-token, a dash is literal text
                if token or inQuote:
                    token += c
                elif not tokens or tokens[-1] != "-":
                    tokens.append("-")
            # normal character
            else:
                token += c
        # if we finished in a token, add it
        if token:
            tokens.append(token)
        return tokens

    # Query building
    ######################################################################

    def _where(self, tokens):
        """Build a SQL WHERE clause from TOKENS.

        Returns (clause, args), where args holds the bound values for the
        clause's '?' placeholders, or (None, None) if any term was invalid.
        Adjacent terms are joined with "and" unless separated by "or".
        """
        # state and query
        s = dict(isnot=False, isor=False, join=False, q="", bad=False)
        args = []
        def add(txt, wrap=True):
            # failed command?
            if not txt:
                # if it was to be negated then we can just ignore it
                if s['isnot']:
                    s['isnot'] = False
                    return
                else:
                    s['bad'] = True
                    return
            elif txt == "skip":
                return
            # do we need a conjunction?
            if s['join']:
                if s['isor']:
                    s['q'] += " or "
                    s['isor'] = False
                else:
                    s['q'] += " and "
            if s['isnot']:
                s['q'] += " not "
                s['isnot'] = False
            if wrap:
                txt = "(" + txt + ")"
            s['q'] += txt
            s['join'] = True
        for token in tokens:
            if s['bad']:
                return None, None
            # special tokens
            if token == "-":
                s['isnot'] = True
            elif token.lower() == "or":
                s['isor'] = True
            elif token == "(":
                add(token, wrap=False)
                # the first term inside a group needs no conjunction
                s['join'] = False
            elif token == ")":
                s['q'] += ")"
            # commands
            elif ":" in token:
                cmd, val = token.split(":", 1)
                cmd = cmd.lower()
                if cmd in self.search:
                    add(self.search[cmd]((val, args)))
                else:
                    # unknown command: treat as "field:value"
                    add(self._findField(cmd, val))
            # normal text search
            else:
                add(self._findText(token, args))
        if s['bad']:
            return None, None
        return s['q'], args

    def _query(self, preds, order):
        """Assemble the card-id SELECT for PREDS and the ORDER clause."""
        # can we skip the note table?
        if "n." not in preds and "n." not in order:
            sql = "select c.id from cards c where "
        else:
            sql = "select c.id from cards c, notes n where c.nid=n.id and "
        # combine with preds
        if preds:
            sql += "(" + preds + ")"
        else:
            sql += "1"
        # order
        if order:
            sql += " " + order
        return sql

    # Ordering
    ######################################################################

    def _order(self, order):
        """Return (order_by_clause, reverse) for findCards' ORDER argument."""
        if not order:
            return "", False
        elif order is not True:
            # custom order string provided
            return " order by " + order, False
        # use the collection's configured sort; unknown values fall back
        # to note creation order
        sortType = self.col.conf['sortType']
        sort = self._SORT_COLUMNS.get(sortType) or "n.id, c.ord"
        return " order by " + sort, self.col.conf['sortBackwards']

    # Commands
    ######################################################################

    def _findTag(self, args):
        """tag:VAL - match a space-delimited tag; '*' is a wildcard."""
        (val, args) = args
        if val == "none":
            return 'n.tags = ""'
        val = val.replace("*", "%")
        # pad with spaces so only whole tags match, unless a wildcard
        # already covers that end
        if not val.startswith("%"):
            val = "% " + val
        if not val.endswith("%") or val.endswith('\\%'):
            val += " %"
        args.append(val)
        return "n.tags like ? escape '\\'"

    def _findCardState(self, args):
        """is:VAL - review/new/learn/suspended/buried/due.

        Returns None (an invalid term) for any other value.
        """
        (val, args) = args
        if val == "review":
            return "c.type = %d" % 2
        elif val == "new":
            return "c.type = %d" % 0
        elif val == "learn":
            return "c.queue in (1, 3)"
        elif val == "suspended":
            return "c.queue = -1"
        elif val == "buried":
            return "c.queue in (-2, -3)"
        elif val == "due":
            return """
(c.queue in (2,3) and c.due <= %d) or
(c.queue = 1 and c.due <= %d)""" % (
    self.col.sched.today, self.col.sched.dayCutoff)

    def _findFlag(self, args):
        """flag:N (N in 0-4) - compare the low 3 bits of c.flags to N."""
        (val, args) = args
        if not val or len(val)!=1 or val not in "01234":
            return
        val = int(val)
        mask = 2**3 - 1
        return "(c.flags & %d) == %d" % (mask, val)

    def _findRated(self, args):
        """rated:DAYS[:EASE] - cards with a revlog entry in the last DAYS days.

        DAYS is capped at 31; EASE, if given, must be 1-4.
        """
        # days(:optional_ease)
        (val, args) = args
        r = val.split(":")
        try:
            days = int(r[0])
        except ValueError:
            return
        days = min(days, 31)
        # ease
        ease = ""
        if len(r) > 1:
            if r[1] not in ("1", "2", "3", "4"):
                return
            ease = "and ease=%s" % r[1]
        # revlog ids are compared against a millisecond cutoff
        cutoff = (self.col.sched.dayCutoff - 86400*days)*1000
        return ("c.id in (select cid from revlog where id>%d %s)" %
                (cutoff, ease))

    def _findAdded(self, args):
        """added:DAYS - cards whose id is above a cutoff DAYS days back.

        Card ids are compared against a millisecond timestamp, presumably
        because ids encode creation time.
        """
        (val, args) = args
        try:
            days = int(val)
        except ValueError:
            return
        cutoff = (self.col.sched.dayCutoff - 86400*days)*1000
        return "c.id > %d" % cutoff

    def _findProp(self, args):
        """prop:NAME<op>VALUE for NAME in due/ivl/reps/lapses/ease."""
        # extract
        (val, args) = args
        m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", val)
        if not m:
            return
        prop, cmp, val = m.groups()
        prop = prop.lower()
        # is val valid?
        try:
            if prop == "ease":
                val = float(val)
            else:
                val = int(val)
        except ValueError:
            return
        # is prop valid? (whitelist, since prop is interpolated into SQL)
        if prop not in ("due", "ivl", "reps", "lapses", "ease"):
            return
        # query
        q = []
        if prop == "due":
            # due is given relative to today
            val += self.col.sched.today
            # only valid for review/daily learning
            q.append("(c.queue in (2,3))")
        elif prop == "ease":
            # ease is stored as an integer factor (ease * 1000)
            prop = "factor"
            val = int(val*1000)
        q.append("(%s %s %s)" % (prop, cmp, val))
        return " and ".join(q)

    def _findText(self, val, args):
        """Plain text - substring match on the sort field or any field."""
        val = val.replace("*", "%")
        args.append("%"+val+"%")
        args.append("%"+val+"%")
        return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')"

    def _findNids(self, args):
        """nid:ID[,ID...] - only digits and commas are accepted."""
        (val, args) = args
        if re.search("[^0-9,]", val):
            return
        return "n.id in (%s)" % val

    def _findCids(self, args):
        """cid:ID[,ID...] - only digits and commas are accepted."""
        (val, args) = args
        if re.search("[^0-9,]", val):
            return
        return "c.id in (%s)" % val

    def _findMid(self, args):
        """mid:ID - only digits are accepted."""
        (val, args) = args
        if re.search("[^0-9]", val):
            return
        return "n.mid = %s" % val

    def _findModel(self, args):
        """note:NAME - case-insensitive, NFC-normalized note type name."""
        (val, args) = args
        ids = []
        val = unicodedata.normalize("NFC", val.lower())
        for m in self.col.models.all():
            if unicodedata.normalize("NFC", m['name'].lower()) == val:
                ids.append(m['id'])
        return "n.mid in %s" % ids2str(ids)

    def _findDeck(self, args):
        """deck:NAME - a deck and its children; '*' is a wildcard.

        deck:* matches everything, deck:filtered matches cards with an
        original deck, and deck:current uses the current deck.
        """
        # if searching for all decks, skip
        (val, args) = args
        if val == "*":
            return "skip"
        # deck types
        elif val == "filtered":
            return "c.odid"
        def dids(did):
            if not did:
                return None
            return [did] + [a[1] for a in self.col.decks.children(did)]
        # current deck?
        ids = None
        if val.lower() == "current":
            ids = dids(self.col.decks.current()['id'])
        elif "*" not in val:
            # single deck
            ids = dids(self.col.decks.id(val, create=False))
        else:
            # wildcard: the pattern must match the whole deck name, so
            # "*b" no longer matches every deck merely containing "b"
            ids = set()
            val = unicodedata.normalize("NFC", val)
            pattern = re.escape(val).replace(r"\*", ".*")
            for d in self.col.decks.all():
                name = unicodedata.normalize("NFC", d['name'])
                if re.fullmatch(pattern, name, re.IGNORECASE):
                    ids.update(dids(d['id']))
        if not ids:
            return
        sids = ids2str(ids)
        return "c.did in %s or c.odid in %s" % (sids, sids)

    def _findTemplate(self, args):
        """card:N (1-based ordinal) or card:NAME (template name)."""
        # were we given an ordinal number?
        (val, args) = args
        try:
            num = int(val) - 1
        except ValueError:
            num = None
        if num is not None:
            return "c.ord = %d" % num
        # search for template names
        wanted = unicodedata.normalize("NFC", val.lower())
        lims = []
        for m in self.col.models.all():
            for t in m['tmpls']:
                if unicodedata.normalize("NFC", t['name'].lower()) == wanted:
                    if m['type'] == MODEL_CLOZE:
                        # if the user has asked for a cloze card, we want
                        # to give all ordinals, so we just limit to the
                        # model instead
                        lims.append("(n.mid = %s)" % m['id'])
                    else:
                        lims.append("(n.mid = %s and c.ord = %s)" % (
                            m['id'], t['ord']))
        # an empty string (no match) is treated as an invalid term
        return " or ".join(lims)

    def _findField(self, field, val):
        """FIELD:VAL - whole-field match where '*' and '_' are wildcards.

        Returns None if no note type has FIELD, "0" if no note matches,
        otherwise an "n.id in (...)" clause.
        """
        field = unicodedata.normalize("NFC", field.lower())
        val = val.replace("*", "%")
        # find models that have that field
        mods = {}
        for m in self.col.models.all():
            for f in m['flds']:
                if unicodedata.normalize("NFC", f['name'].lower()) == field:
                    mods[str(m['id'])] = (m, f['ord'])
        if not mods:
            # nothing has that field
            return
        # gather nids: SQL LIKE pre-filters candidates, then a regex
        # equivalent of the LIKE pattern checks the specific field
        regex = re.escape(val).replace("_", ".").replace(re.escape("%"), ".*")
        try:
            matcher = re.compile("(?si)^"+regex+"$")
        except re.error:
            return
        nids = []
        for (id,mid,flds) in self.col.db.execute("""
select id, mid, flds from notes
where mid in %s and flds like ? escape '\\'""" % (
                         ids2str(list(mods.keys()))),
                         "%"+val+"%"):
            flds = splitFields(flds)
            ord = mods[str(mid)][1]
            strg = flds[ord]
            if matcher.search(strg):
                nids.append(id)
        if not nids:
            return "0"
        return "n.id in %s" % ids2str(nids)

    def _findDupes(self, args):
        """dupe:MID,TEXT - notes of type MID whose first field equals TEXT.

        The caller must call stripHTMLMedia on the passed TEXT.
        """
        (val, args) = args
        try:
            mid, val = val.split(",", 1)
        except ValueError:
            # no comma: malformed term
            return
        csum = fieldChecksum(val)
        nids = []
        for nid, flds in self.col.db.execute(
                "select id, flds from notes where mid=? and csum=?",
                mid, csum):
            # the checksum can collide, so confirm the actual field text
            if stripHTMLMedia(splitFields(flds)[0]) == val:
                nids.append(nid)
        return "n.id in %s" % ids2str(nids)
478
479# Find and replace
480##########################################################################
481
def findReplace(col, nids, src, dst, regex=False, field=None, fold=True):
    """Find and replace text in the fields of the notes NIDS.

    SRC/DST: the search text and its replacement. Unless REGEX is true, SRC
        is matched literally and backslashes in DST are kept literally.
    FIELD: if given, only the field with this name (case-insensitive,
        NFC-normalized) is changed, and notes whose model lacks it are
        skipped. Returns 0 early if no model has the field.
    FOLD: match case-insensitively.

    Returns the number of notes whose fields changed. Raises re.error if
    REGEX is true and SRC is not a valid pattern.
    """
    fieldOrds = {}
    if field:
        wanted = unicodedata.normalize("NFC", field.lower())
        for m in col.models.all():
            for f in m['flds']:
                if unicodedata.normalize("NFC", f['name'].lower()) == wanted:
                    fieldOrds[str(m['id'])] = f['ord']
        if not fieldOrds:
            return 0
    # find and gather replacements
    if not regex:
        src = re.escape(src)
        # re.sub interprets backslashes in the replacement; keep them literal
        dst = dst.replace("\\", "\\\\")
    pattern = re.compile(src, re.IGNORECASE if fold else 0)
    updates = []
    changed = []
    for nid, mid, flds in col.db.execute(
            "select id, mid, flds from notes where id in " + ids2str(nids)):
        sflds = splitFields(flds)
        if field:
            fieldOrd = fieldOrds.get(str(mid))
            if fieldOrd is None:
                # note doesn't have that field
                continue
            sflds[fieldOrd] = pattern.sub(dst, sflds[fieldOrd])
        else:
            sflds = [pattern.sub(dst, value) for value in sflds]
        newFlds = joinFields(sflds)
        # only write back notes whose content actually changed
        if newFlds != flds:
            changed.append(nid)
            updates.append(dict(nid=nid, flds=newFlds, u=col.usn(), m=intTime()))
    if not updates:
        return 0
    # replace
    col.db.executemany(
        "update notes set flds=:flds,mod=:m,usn=:u where id=:nid", updates)
    col.updateFieldCache(changed)
    col.genCards(changed)
    return len(updates)
531
def fieldNames(col, downcase=True):
    """Return the distinct field names used by every note type in COL.

    Names are lowercased when DOWNCASE is true. The order of the returned
    list is unspecified.
    """
    names = {
        fld['name'].lower() if downcase else fld['name']
        for model in col.models.all()
        for fld in model['flds']
    }
    return list(names)
540
def fieldNamesForNotes(col, nids):
    """Return the distinct field names of the note types used by NIDS.

    The result is sorted case-insensitively.
    """
    query = "select distinct mid from notes where id in %s" % ids2str(nids)
    names = set()
    for mid in col.db.list(query):
        names.update(col.models.fieldNames(col.models.get(mid)))
    return sorted(names, key=lambda name: name.lower())
550
551# Find duplicates
552##########################################################################
# Returns a list of (dupestr, [nids]) tuples, one per duplicated field value.
def findDupes(col, fieldName, search=""):
    """Find notes whose FIELDNAME field has identical content.

    Field content is compared after stripHTMLMedia, and empty values are
    ignored. SEARCH optionally restricts which notes are considered.

    Returns a list of (value, [nids]) tuples, one per value shared by two
    or more notes.
    """
    # limit search to notes with applicable field name; quote with a
    # character not present in the name so the tokenizer keeps it whole
    quote = '"' if "'" in fieldName else "'"
    if search:
        search = "(" + search + ") "
    search += "%s%s:*%s" % (quote, fieldName, quote)
    # match field names the same way the field search does
    wanted = unicodedata.normalize("NFC", fieldName.lower())
    ordCache = {}
    def ordForMid(mid):
        # position of the field in this note type, or None if absent
        if mid not in ordCache:
            model = col.models.get(mid)
            ordCache[mid] = next(
                (i for i, f in enumerate(model['flds'])
                 if unicodedata.normalize("NFC", f['name'].lower()) == wanted),
                None)
        return ordCache[mid]
    # go through notes
    vals = {}
    dupes = []
    for nid, mid, flds in col.db.all(
            "select id, mid, flds from notes where id in " + ids2str(
                col.findNotes(search))):
        flds = splitFields(flds)
        fieldOrd = ordForMid(mid)
        if fieldOrd is None:
            continue
        val = stripHTMLMedia(flds[fieldOrd])
        # empty does not count as duplicate
        if not val:
            continue
        group = vals.setdefault(val, [])
        group.append(nid)
        # record each group once, when it first becomes a duplicate; later
        # nids are appended to the same list object
        if len(group) == 2:
            dupes.append((val, group))
    return dupes
589