1# -*- coding: utf-8 -*-
2# Copyright: Ankitects Pty Ltd and contributors
3# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
4
5import copy, re, json
6from anki.utils import intTime, joinFields, splitFields, ids2str,\
7    checksum
8from anki.lang import _
9from anki.consts import *
10from anki.hooks import runHook
11import time
12
13# Models
14##########################################################################
15
16# - careful not to add any lists/dicts/etc here, as they aren't deep copied
17
# Baseline values for a new model; ModelManager.new() copies this mapping and
# then fills in the per-model keys (name, flds, tmpls, tags, id).
defaultModel = dict(
    sortf=0,
    did=1,
    latexPre="""\
\\documentclass[12pt]{article}
\\special{papersize=3in,5in}
\\usepackage[utf8]{inputenc}
\\usepackage{amssymb,amsmath}
\\pagestyle{empty}
\\setlength{\\parindent}{0in}
\\begin{document}
""",
    latexPost="\\end{document}",
    mod=0,
    usn=0,
    vers=[],  # FIXME: remove when other clients have caught up
    type=MODEL_STD,
    css="""\
.card {
 font-family: arial;
 font-size: 20px;
 text-align: center;
 color: black;
 background-color: white;
}
""",
)
45
# Baseline values for a new field; ModelManager.newField() copies this mapping
# and then sets the name.
defaultField = dict(
    name="",
    ord=None,
    sticky=False,
    # the following alter editing, and are used as defaults for the
    # template wizard
    rtl=False,
    font="Arial",
    size=20,
    # reserved for future use
    media=[],
)
58
# Baseline values for a new card template; ModelManager.newTemplate() copies
# this mapping and then sets the name.
defaultTemplate = dict(
    name="",
    ord=None,
    qfmt="",
    afmt="",
    did=None,
    bqfmt="",
    bafmt="",
    # we don't define these so that we pick up system font size until set
    # bfont="Arial",
    # bsize=12,
)
71
72class ModelManager:
73
74    # Saving/loading registry
75    #############################################################
76
77    def __init__(self, col):
78        self.col = col
79
80    def load(self, json_):
81        "Load registry from JSON."
82        self.changed = False
83        self.models = json.loads(json_)
84
85    def save(self, m=None, templates=False):
86        "Mark M modified if provided, and schedule registry flush."
87        if m and m['id']:
88            m['mod'] = intTime()
89            m['usn'] = self.col.usn()
90            self._updateRequired(m)
91            if templates:
92                self._syncTemplates(m)
93        self.changed = True
94        runHook("newModel")
95
96    def flush(self):
97        "Flush the registry if any models were changed."
98        if self.changed:
99            self.ensureNotEmpty()
100            self.col.db.execute("update col set models = ?",
101                                 json.dumps(self.models))
102            self.changed = False
103
104    def ensureNotEmpty(self):
105        if not self.models:
106            from anki.stdmodels import addBasicModel
107            addBasicModel(self.col)
108            return True
109
110    # Retrieving and creating models
111    #############################################################
112
113    def current(self, forDeck=True):
114        "Get current model."
115        m = self.get(self.col.decks.current().get('mid'))
116        if not forDeck or not m:
117            m = self.get(self.col.conf['curModel'])
118        return m or list(self.models.values())[0]
119
120    def setCurrent(self, m):
121        self.col.conf['curModel'] = m['id']
122        self.col.setMod()
123
124    def get(self, id):
125        "Get model with ID, or None."
126        id = str(id)
127        if id in self.models:
128            return self.models[id]
129
130    def all(self):
131        "Get all models."
132        return list(self.models.values())
133
134    def allNames(self):
135        return [m['name'] for m in self.all()]
136
137    def byName(self, name):
138        "Get model with NAME."
139        for m in list(self.models.values()):
140            if m['name'] == name:
141                return m
142
143    def new(self, name):
144        "Create a new model, save it in the registry, and return it."
145        # caller should call save() after modifying
146        m = defaultModel.copy()
147        m['name'] = name
148        m['mod'] = intTime()
149        m['flds'] = []
150        m['tmpls'] = []
151        m['tags'] = []
152        m['id'] = None
153        return m
154
155    def rem(self, m):
156        "Delete model, and all its cards/notes."
157        self.col.modSchema(check=True)
158        current = self.current()['id'] == m['id']
159        # delete notes/cards
160        self.col.remCards(self.col.db.list("""
161select id from cards where nid in (select id from notes where mid = ?)""",
162                                      m['id']))
163        # then the model
164        del self.models[str(m['id'])]
165        self.save()
166        # GUI should ensure last model is not deleted
167        if current:
168            self.setCurrent(list(self.models.values())[0])
169
170    def add(self, m):
171        self._setID(m)
172        self.update(m)
173        self.setCurrent(m)
174        self.save(m)
175
176    def ensureNameUnique(self, m):
177        for mcur in self.all():
178            if (mcur['name'] == m['name'] and mcur['id'] != m['id']):
179                m['name'] += "-" + checksum(str(time.time()))[:5]
180                break
181
182    def update(self, m):
183        "Add or update an existing model. Used for syncing and merging."
184        self.ensureNameUnique(m)
185        self.models[str(m['id'])] = m
186        # mark registry changed, but don't bump mod time
187        self.save()
188
189    def _setID(self, m):
190        while 1:
191            id = str(intTime(1000))
192            if id not in self.models:
193                break
194        m['id'] = id
195
196    def have(self, id):
197        return str(id) in self.models
198
199    def ids(self):
200        return list(self.models.keys())
201
202    # Tools
203    ##################################################
204
205    def nids(self, m):
206        "Note ids for M."
207        return self.col.db.list(
208            "select id from notes where mid = ?", m['id'])
209
210    def useCount(self, m):
211        "Number of note using M."
212        return self.col.db.scalar(
213            "select count() from notes where mid = ?", m['id'])
214
215    def tmplUseCount(self, m, ord):
216        return self.col.db.scalar("""
217select count() from cards, notes where cards.nid = notes.id
218and notes.mid = ? and cards.ord = ?""", m['id'], ord)
219
220    # Copying
221    ##################################################
222
223    def copy(self, m):
224        "Copy, save and return."
225        m2 = copy.deepcopy(m)
226        m2['name'] = _("%s copy") % m2['name']
227        self.add(m2)
228        return m2
229
230    # Fields
231    ##################################################
232
233    def newField(self, name):
234        assert(isinstance(name, str))
235        f = defaultField.copy()
236        f['name'] = name
237        return f
238
239    def fieldMap(self, m):
240        "Mapping of field name -> (ord, field)."
241        return dict((f['name'], (f['ord'], f)) for f in m['flds'])
242
243    def fieldNames(self, m):
244        return [f['name'] for f in m['flds']]
245
246    def sortIdx(self, m):
247        return m['sortf']
248
249    def setSortIdx(self, m, idx):
250        assert 0 <= idx < len(m['flds'])
251        self.col.modSchema(check=True)
252        m['sortf'] = idx
253        self.col.updateFieldCache(self.nids(m))
254        self.save(m)
255
256    def addField(self, m, field):
257        # only mod schema if model isn't new
258        if m['id']:
259            self.col.modSchema(check=True)
260        m['flds'].append(field)
261        self._updateFieldOrds(m)
262        self.save(m)
263        def add(fields):
264            fields.append("")
265            return fields
266        self._transformFields(m, add)
267
268    def remField(self, m, field):
269        self.col.modSchema(check=True)
270        # save old sort field
271        sortFldName = m['flds'][m['sortf']]['name']
272        idx = m['flds'].index(field)
273        m['flds'].remove(field)
274        # restore old sort field if possible, or revert to first field
275        m['sortf'] = 0
276        for c, f in enumerate(m['flds']):
277            if f['name'] == sortFldName:
278                m['sortf'] = c
279                break
280        self._updateFieldOrds(m)
281        def delete(fields):
282            del fields[idx]
283            return fields
284        self._transformFields(m, delete)
285        if m['flds'][m['sortf']]['name'] != sortFldName:
286            # need to rebuild sort field
287            self.col.updateFieldCache(self.nids(m))
288        # saves
289        self.renameField(m, field, None)
290
291    def moveField(self, m, field, idx):
292        self.col.modSchema(check=True)
293        oldidx = m['flds'].index(field)
294        if oldidx == idx:
295            return
296        # remember old sort field
297        sortf = m['flds'][m['sortf']]
298        # move
299        m['flds'].remove(field)
300        m['flds'].insert(idx, field)
301        # restore sort field
302        m['sortf'] = m['flds'].index(sortf)
303        self._updateFieldOrds(m)
304        self.save(m)
305        def move(fields, oldidx=oldidx):
306            val = fields[oldidx]
307            del fields[oldidx]
308            fields.insert(idx, val)
309            return fields
310        self._transformFields(m, move)
311
312    def renameField(self, m, field, newName):
313        self.col.modSchema(check=True)
314        pat = r'{{([^{}]*)([:#^/]|[^:#/^}][^:}]*?:|)%s}}'
315        def wrap(txt):
316            def repl(match):
317                return '{{' + match.group(1) + match.group(2) + txt +  '}}'
318            return repl
319        for t in m['tmpls']:
320            for fmt in ('qfmt', 'afmt'):
321                if newName:
322                    t[fmt] = re.sub(
323                        pat % re.escape(field['name']), wrap(newName), t[fmt])
324                else:
325                    t[fmt] = re.sub(
326                        pat  % re.escape(field['name']), "", t[fmt])
327        field['name'] = newName
328        self.save(m)
329
330    def _updateFieldOrds(self, m):
331        for c, f in enumerate(m['flds']):
332            f['ord'] = c
333
334    def _transformFields(self, m, fn):
335        # model hasn't been added yet?
336        if not m['id']:
337            return
338        r = []
339        for (id, flds) in self.col.db.execute(
340            "select id, flds from notes where mid = ?", m['id']):
341            r.append((joinFields(fn(splitFields(flds))),
342                      intTime(), self.col.usn(), id))
343        self.col.db.executemany(
344            "update notes set flds=?,mod=?,usn=? where id = ?", r)
345
346    # Templates
347    ##################################################
348
349    def newTemplate(self, name):
350        t = defaultTemplate.copy()
351        t['name'] = name
352        return t
353
354    def addTemplate(self, m, template):
355        "Note: should col.genCards() afterwards."
356        if m['id']:
357            self.col.modSchema(check=True)
358        m['tmpls'].append(template)
359        self._updateTemplOrds(m)
360        self.save(m)
361
362    def remTemplate(self, m, template):
363        "False if removing template would leave orphan notes."
364        assert len(m['tmpls']) > 1
365        # find cards using this template
366        ord = m['tmpls'].index(template)
367        cids = self.col.db.list("""
368select c.id from cards c, notes f where c.nid=f.id and mid = ? and ord = ?""",
369                                 m['id'], ord)
370        # all notes with this template must have at least two cards, or we
371        # could end up creating orphaned notes
372        if self.col.db.scalar("""
373select nid, count() from cards where
374nid in (select nid from cards where id in %s)
375group by nid
376having count() < 2
377limit 1""" % ids2str(cids)):
378            return False
379        # ok to proceed; remove cards
380        self.col.modSchema(check=True)
381        self.col.remCards(cids)
382        # shift ordinals
383        self.col.db.execute("""
384update cards set ord = ord - 1, usn = ?, mod = ?
385 where nid in (select id from notes where mid = ?) and ord > ?""",
386                             self.col.usn(), intTime(), m['id'], ord)
387        m['tmpls'].remove(template)
388        self._updateTemplOrds(m)
389        self.save(m)
390        return True
391
392    def _updateTemplOrds(self, m):
393        for c, t in enumerate(m['tmpls']):
394            t['ord'] = c
395
396    def moveTemplate(self, m, template, idx):
397        oldidx = m['tmpls'].index(template)
398        if oldidx == idx:
399            return
400        oldidxs = dict((id(t), t['ord']) for t in m['tmpls'])
401        m['tmpls'].remove(template)
402        m['tmpls'].insert(idx, template)
403        self._updateTemplOrds(m)
404        # generate change map
405        map = []
406        for t in m['tmpls']:
407            map.append("when ord = %d then %d" % (oldidxs[id(t)], t['ord']))
408        # apply
409        self.save(m)
410        self.col.db.execute("""
411update cards set ord = (case %s end),usn=?,mod=? where nid in (
412select id from notes where mid = ?)""" % " ".join(map),
413                             self.col.usn(), intTime(), m['id'])
414
415    def _syncTemplates(self, m):
416        rem = self.col.genCards(self.nids(m))
417
418    # Model changing
419    ##########################################################################
420    # - maps are ord->ord, and there should not be duplicate targets
421    # - newModel should be self if model is not changing
422
423    def change(self, m, nids, newModel, fmap, cmap):
424        self.col.modSchema(check=True)
425        assert newModel['id'] == m['id'] or (fmap and cmap)
426        if fmap:
427            self._changeNotes(nids, newModel, fmap)
428        if cmap:
429            self._changeCards(nids, m, newModel, cmap)
430        self.col.genCards(nids)
431
432    def _changeNotes(self, nids, newModel, map):
433        d = []
434        nfields = len(newModel['flds'])
435        for (nid, flds) in self.col.db.execute(
436            "select id, flds from notes where id in "+ids2str(nids)):
437            newflds = {}
438            flds = splitFields(flds)
439            for old, new in list(map.items()):
440                newflds[new] = flds[old]
441            flds = []
442            for c in range(nfields):
443                flds.append(newflds.get(c, ""))
444            flds = joinFields(flds)
445            d.append(dict(nid=nid, flds=flds, mid=newModel['id'],
446                      m=intTime(),u=self.col.usn()))
447        self.col.db.executemany(
448            "update notes set flds=:flds,mid=:mid,mod=:m,usn=:u where id = :nid", d)
449        self.col.updateFieldCache(nids)
450
451    def _changeCards(self, nids, oldModel, newModel, map):
452        d = []
453        deleted = []
454        for (cid, ord) in self.col.db.execute(
455            "select id, ord from cards where nid in "+ids2str(nids)):
456            # if the src model is a cloze, we ignore the map, as the gui
457            # doesn't currently support mapping them
458            if oldModel['type'] == MODEL_CLOZE:
459                new = ord
460                if newModel['type'] != MODEL_CLOZE:
461                    # if we're mapping to a regular note, we need to check if
462                    # the destination ord is valid
463                    if len(newModel['tmpls']) <= ord:
464                        new = None
465            else:
466                # mapping from a regular note, so the map should be valid
467                new = map[ord]
468            if new is not None:
469                d.append(dict(
470                    cid=cid,new=new,u=self.col.usn(),m=intTime()))
471            else:
472                deleted.append(cid)
473        self.col.db.executemany(
474            "update cards set ord=:new,usn=:u,mod=:m where id=:cid",
475            d)
476        self.col.remCards(deleted)
477
478    # Schema hash
479    ##########################################################################
480
481    def scmhash(self, m):
482        "Return a hash of the schema, to see if models are compatible."
483        s = ""
484        for f in m['flds']:
485            s += f['name']
486        for t in m['tmpls']:
487            s += t['name']
488        return checksum(s)
489
490    # Required field/text cache
491    ##########################################################################
492
493    def _updateRequired(self, m):
494        if m['type'] == MODEL_CLOZE:
495            # nothing to do
496            return
497        req = []
498        flds = [f['name'] for f in m['flds']]
499        for t in m['tmpls']:
500            ret = self._reqForTemplate(m, flds, t)
501            req.append((t['ord'], ret[0], ret[1]))
502        m['req'] = req
503
504    def _reqForTemplate(self, m, flds, t):
505        a = []
506        b = []
507        for f in flds:
508            a.append("ankiflag")
509            b.append("")
510        data = [1, 1, m['id'], 1, t['ord'], "", joinFields(a), 0]
511        full = self.col._renderQA(data)['q']
512        data = [1, 1, m['id'], 1, t['ord'], "", joinFields(b), 0]
513        empty = self.col._renderQA(data)['q']
514        # if full and empty are the same, the template is invalid and there is
515        # no way to satisfy it
516        if full == empty:
517            return "none", [], []
518        type = 'all'
519        req = []
520        for i in range(len(flds)):
521            tmp = a[:]
522            tmp[i] = ""
523            data[6] = joinFields(tmp)
524            # if no field content appeared, field is required
525            if "ankiflag" not in self.col._renderQA(data)['q']:
526                req.append(i)
527        if req:
528            return type, req
529        # if there are no required fields, switch to any mode
530        type = 'any'
531        req = []
532        for i in range(len(flds)):
533            tmp = b[:]
534            tmp[i] = "1"
535            data[6] = joinFields(tmp)
536            # if not the same as empty, this field can make the card non-blank
537            if self.col._renderQA(data)['q'] != empty:
538                req.append(i)
539        return type, req
540
541    def availOrds(self, m, flds):
542        "Given a joined field string, return available template ordinals."
543        if m['type'] == MODEL_CLOZE:
544            return self._availClozeOrds(m, flds)
545        fields = {}
546        for c, f in enumerate(splitFields(flds)):
547            fields[c] = f.strip()
548        avail = []
549        for ord, type, req in m['req']:
550            # unsatisfiable template
551            if type == "none":
552                continue
553            # AND requirement?
554            elif type == "all":
555                ok = True
556                for idx in req:
557                    if not fields[idx]:
558                        # missing and was required
559                        ok = False
560                        break
561                if not ok:
562                    continue
563            # OR requirement?
564            elif type == "any":
565                ok = False
566                for idx in req:
567                    if fields[idx]:
568                        ok = True
569                        break
570                if not ok:
571                    continue
572            avail.append(ord)
573        return avail
574
575    def _availClozeOrds(self, m, flds, allowEmpty=True):
576        sflds = splitFields(flds)
577        map = self.fieldMap(m)
578        ords = set()
579        matches = re.findall("{{[^}]*?cloze:(?:[^}]?:)*(.+?)}}", m['tmpls'][0]['qfmt'])
580        matches += re.findall("<%cloze:(.+?)%>", m['tmpls'][0]['qfmt'])
581        for fname in matches:
582            if fname not in map:
583                continue
584            ord = map[fname][0]
585            ords.update([int(m)-1 for m in re.findall(
586                r"(?s){{c(\d+)::.+?}}", sflds[ord])])
587        if -1 in ords:
588            ords.remove(-1)
589        if not ords and allowEmpty:
590            # empty clozes use first ord
591            return [0]
592        return list(ords)
593
594    # Sync handling
595    ##########################################################################
596
597    def beforeUpload(self):
598        for m in self.all():
599            m['usn'] = 0
600        self.save()
601