1# Copyright 2012 Matt Chaput. All rights reserved.
2#
3# Redistribution and use in source and binary forms, with or without
4# modification, are permitted provided that the following conditions are met:
5#
6#    1. Redistributions of source code must retain the above copyright notice,
7#       this list of conditions and the following disclaimer.
8#
9#    2. Redistributions in binary form must reproduce the above copyright
10#       notice, this list of conditions and the following disclaimer in the
11#       documentation and/or other materials provided with the distribution.
12#
13# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
14# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
17# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
19# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
20# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
21# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
22# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23#
24# The views and conclusions contained in the software and documentation are
25# those of the authors and should not be interpreted as representing official
26# policies, either expressed or implied, of Matt Chaput.
27
28from ast import literal_eval
29
30from whoosh.compat import b, bytes_type, text_type, integer_types, PY3
31from whoosh.compat import iteritems, dumps, loads, xrange
32from whoosh.codec import base
33from whoosh.matching import ListMatcher
34from whoosh.reading import TermInfo, TermNotFound
35
# LineWriter._print_line() tests values with isinstance(v, memoryview).
# When PY3 is false, an empty placeholder class is bound to that name here
# so the check always has something to refer to.
if not PY3:
    class memoryview:
        pass

# Value types (besides None) that LineWriter._print_line() is willing to
# write; any other type makes it raise TypeError.
_reprable = (bytes_type, text_type, integer_types, float)
41
42
43# Mixin classes for producing and consuming the simple text format
44
class LineWriter(object):
    """Mixin that writes records in the simple indented text format.

    The host class must set ``self._dbfile`` to a binary file-like object
    with a ``write()`` method.
    """

    def _print_line(self, indent, command, **kwargs):
        """Write one record line to ``self._dbfile``.

        The line consists of two spaces per indent level, the command name,
        then one ``\\t<name>=<repr(value)>`` item per keyword argument, and
        a terminating newline.

        :param indent: nesting level of the record.
        :param command: record name (text, encoded as Latin-1).
        :param kwargs: record arguments. Each value must be ``None`` or an
            instance of one of the ``_reprable`` types; ``memoryview``
            values are converted to ``bytes`` first.
        :raises TypeError: if a value has an unsupported type. Anything
            written before the offending argument stays in the file.
        """
        self._dbfile.write(b("  ") * indent)
        self._dbfile.write(command.encode("latin1"))
        for k, v in iteritems(kwargs):
            if isinstance(v, memoryview):
                v = bytes(v)
            if v is not None and not isinstance(v, _reprable):
                raise TypeError(type(v))
            # On Python 3, repr() leaves printable non-ASCII characters
            # untouched, so a strict Latin-1 encode would raise
            # UnicodeEncodeError for text outside Latin-1. With
            # "backslashreplace" those characters are written as \uXXXX
            # escapes inside the quoted literal, which literal_eval() turns
            # back into the original characters when the line is read.
            item = "\t%s=%r" % (k, v)
            self._dbfile.write(item.encode("latin1", "backslashreplace"))
        self._dbfile.write(b("\n"))
56
57
class LineReader(object):
    """Mixin that scans a file written in the simple indented text format.

    Each record is one line: two spaces of indentation per nesting level,
    a command name, then tab-separated ``name=<python literal>`` arguments.
    All searches are sequential and start at the current position of
    ``self._dbfile``.
    """

    def __init__(self, dbfile):
        """Remember ``dbfile``, a binary file-like object supporting
        ``readline()``, ``seek()`` and ``tell()``.
        """
        self._dbfile = dbfile

    def _reset(self):
        """Move back to the start of the file."""
        self._dbfile.seek(0)

    def _find_line(self, indent, command, **kwargs):
        """Return the argument dict of the next matching line, or ``None``.

        Takes the same arguments as :meth:`_find_lines` and stops after the
        first match, leaving the file positioned just after that line.
        """
        for largs in self._find_lines(indent, command, **kwargs):
            return largs
        return None

    def _find_lines(self, indent, command, **kwargs):
        """Yield the argument dict of every matching line in this section.

        A line matches when it is at exactly ``indent``, its command is
        ``command``, and every keyword argument equals the corresponding
        parsed argument. Deeper or non-matching lines are skipped, as are
        blank and ``#`` comment lines.

        The scan ends at end of file or at the first line whose indent is
        smaller than ``indent``. That line belongs to the enclosing section,
        so the file is repositioned to its start and the next search for the
        parent record can still see it.
        """
        dbfile = self._dbfile
        while True:
            linestart = dbfile.tell()
            line = dbfile.readline()
            if not line:
                return

            parsed = self._parse_line(line)
            if parsed is None:
                # Blank line or comment: carries no record, keep scanning.
                continue

            lindent, lcommand, largs = parsed
            if lindent < indent:
                # Left the current section; un-read the line for the caller.
                dbfile.seek(linestart)
                return

            if lindent == indent and lcommand == command:
                if not any(kwargs[k] != largs.get(k) for k in kwargs):
                    yield largs

    def _parse_line(self, line):
        """Parse one raw line into ``(indent, command, args)``.

        :param line: a bytes line as read from the file.
        :returns: ``None`` for a blank or ``#`` comment line, otherwise a
            tuple of the indent level (leading whitespace width // 2), the
            command name, and a dict of arguments evaluated with
            ``literal_eval``.
        :raises ValueError: if an argument has no ``=`` separator.
        """
        text = line.decode("latin1").rstrip()
        stripped = text.lstrip()
        if not stripped or stripped.startswith("#"):
            return None

        indent = (len(text) - len(stripped)) // 2

        parts = stripped.split("\t")
        command = parts[0]
        args = {}
        for part in parts[1:]:
            # Split on the first "=" only: the repr of a value may itself
            # contain "=" characters.
            name, value = part.split("=", 1)
            args[name] = literal_eval(value)
        return (indent, command, args)

    def _find_root(self, command):
        """Rewind the file and position just after the root ``command`` line.

        :raises Exception: if no such line exists at indent level 0.
        """
        self._reset()
        c = self._find_line(0, command)
        if c is None:
            raise Exception("No root section %r" % (command,))
116
117
118# Codec class
119
class PlainTextCodec(base.Codec):
    """Codec whose readers and writers use the simple text line format.

    Every factory method simply builds the matching ``Plain*`` class
    defined in this module.
    """

    length_stats = False

    def per_document_writer(self, storage, segment):
        """Return a new :class:`PlainPerDocWriter` for ``segment``."""
        writer = PlainPerDocWriter(storage, segment)
        return writer

    def field_writer(self, storage, segment):
        """Return a new :class:`PlainFieldWriter` for ``segment``."""
        writer = PlainFieldWriter(storage, segment)
        return writer

    def per_document_reader(self, storage, segment):
        """Return a new :class:`PlainPerDocReader` for ``segment``."""
        reader = PlainPerDocReader(storage, segment)
        return reader

    def terms_reader(self, storage, segment):
        """Return a new :class:`PlainTermsReader` for ``segment``."""
        reader = PlainTermsReader(storage, segment)
        return reader

    def new_segment(self, storage, indexname):
        """Return a new :class:`PlainSegment` named ``indexname``.

        ``storage`` is accepted but not used.
        """
        segment = PlainSegment(indexname)
        return segment
137
138
class PlainPerDocWriter(base.PerDocumentWriter, LineWriter):
    """Writes per-document records to the segment's ``.dcs`` file.

    The file starts with a ``DOCS`` root line; each document adds a ``DOC``
    line followed by nested ``DOCFIELD``, ``COLVAL`` and ``VECTOR``/``VPOST``
    lines.
    """

    def __init__(self, storage, segment):
        """Create the ``.dcs`` file in ``storage`` and write the root line."""
        filename = segment.make_filename(".dcs")
        self._dbfile = storage.create_file(filename)
        self._print_line(0, "DOCS")
        self.is_closed = False

    def start_doc(self, docnum):
        """Write the ``DOC`` line that opens document ``docnum``."""
        self._print_line(1, "DOC", dn=docnum)

    def add_field(self, fieldname, fieldobj, value, length):
        """Write a ``DOCFIELD`` line with the field's value and length.

        A value other than ``None`` is serialized with ``dumps(value, 2)``
        before being written. ``fieldobj`` is not used.
        """
        stored = None if value is None else dumps(value, 2)
        self._print_line(2, "DOCFIELD", fn=fieldname, v=stored, len=length)

    def add_column_value(self, fieldname, columnobj, value):
        """Write a ``COLVAL`` line for ``fieldname``; ``columnobj`` is unused."""
        self._print_line(2, "COLVAL", fn=fieldname, v=value)

    def add_vector_items(self, fieldname, fieldobj, items):
        """Write a ``VECTOR`` line followed by one ``VPOST`` line per item.

        ``items`` yields ``(text, weight, vbytes)`` tuples. ``fieldobj`` is
        not used.
        """
        self._print_line(2, "VECTOR", fn=fieldname)
        for posting in items:
            text, weight, vbytes = posting
            self._print_line(3, "VPOST", t=text, w=weight, v=vbytes)

    def finish_doc(self):
        """Nothing to do: documents have no closing line in this format."""
        pass

    def close(self):
        """Close the underlying file and mark this writer as closed."""
        self._dbfile.close()
        self.is_closed = True
167
168
class PlainPerDocReader(base.PerDocumentReader, LineReader):
    """Reads per-document records from the segment's ``.dcs`` text file.

    Every lookup is a sequential scan of the file from the ``DOCS`` root
    line, so this reader is simple rather than fast.
    """

    def __init__(self, storage, segment):
        """Open the segment's ``.dcs`` file from ``storage``."""
        self._dbfile = storage.open_file(segment.make_filename(".dcs"))
        self._segment = segment
        self.is_closed = False

    def doc_count(self):
        """Return the document count reported by the segment."""
        return self._segment.doc_count()

    def doc_count_all(self):
        """Return the same value as :meth:`doc_count`."""
        return self._segment.doc_count()

    def has_deletions(self):
        """Always ``False``: this reader reports no deletions."""
        return False

    def is_deleted(self, docnum):
        """Always ``False``: this reader reports no deletions."""
        return False

    def deleted_docs(self):
        """Return an empty frozenset."""
        return frozenset()

    def _find_doc(self, docnum):
        """Position the file just after the ``DOC`` line for ``docnum``.

        :returns: ``True`` if the document was found. ``False`` if the file
            ends first or a larger document number is met; the scan gives
            up at the first larger number.
        """
        self._find_root("DOCS")
        c = self._find_line(1, "DOC")
        while c is not None:
            dn = c["dn"]
            if dn == docnum:
                return True
            elif dn > docnum:
                return False
            c = self._find_line(1, "DOC")
        return False

    def _iter_docs(self):
        """Yield each document number, leaving the file just after its
        ``DOC`` line so the caller can read that document's records.
        """
        self._find_root("DOCS")
        c = self._find_line(1, "DOC")
        while c is not None:
            yield c["dn"]
            c = self._find_line(1, "DOC")

    def _iter_docfields(self, fieldname):
        """Yield the argument dict of every ``DOCFIELD`` line for
        ``fieldname`` across the documents visited.
        """
        for _ in self._iter_docs():
            for c in self._find_lines(2, "DOCFIELD", fn=fieldname):
                yield c

    def _iter_lengths(self, fieldname):
        """Return a generator of the ``len`` value (default 0) of each
        ``DOCFIELD`` line for ``fieldname``.
        """
        return (c.get("len", 0) for c in self._iter_docfields(fieldname))

    def doc_field_length(self, docnum, fieldname, default=0):
        """Return the stored length of ``fieldname`` in document ``docnum``.

        :returns: ``default`` if the document or the field line is missing,
            or if the field line has no ``len`` argument.
        """
        for dn in self._iter_docs():
            if dn == docnum:
                c = self._find_line(2, "DOCFIELD", fn=fieldname)
                if c is not None:
                    return c.get("len", default)
            elif dn > docnum:
                break

        return default

    def _column_values(self, fieldname):
        """Yield the ``COLVAL`` value of ``fieldname`` for each document.

        :raises Exception: if the document numbers are not 0, 1, 2, ... in
            file order, or if a document has no ``COLVAL`` line for the
            field.
        """
        for i, docnum in enumerate(self._iter_docs()):
            if i != docnum:
                raise Exception("Missing column value for field %r doc %d?"
                                % (fieldname, i))

            c = self._find_line(2, "COLVAL", fn=fieldname)
            if c is None:
                raise Exception("Missing column value for field %r doc %d?"
                                % (fieldname, docnum))

            yield c.get("v")

    def has_column(self, fieldname):
        """Return ``True`` if a ``COLVAL`` line for ``fieldname`` is found.

        This is a plain existence check. It returns ``False`` when the field
        has no column values instead of propagating the "Missing column
        value" exception that :meth:`_column_values` raises.
        """
        for _ in self._iter_docs():
            if self._find_line(2, "COLVAL", fn=fieldname) is not None:
                return True
        return False

    def column_reader(self, fieldname, column):
        """Return all column values of ``fieldname`` as a list.

        ``column`` is not used.

        :raises Exception: see :meth:`_column_values`.
        """
        return list(self._column_values(fieldname))

    def field_length(self, fieldname):
        """Return the sum of the stored lengths of ``fieldname``."""
        return sum(self._iter_lengths(fieldname))

    def min_field_length(self, fieldname):
        """Return the smallest stored length of ``fieldname``.

        :raises ValueError: if no length is stored for the field.
        """
        return min(self._iter_lengths(fieldname))

    def max_field_length(self, fieldname):
        """Return the largest stored length of ``fieldname``.

        :raises ValueError: if no length is stored for the field.
        """
        return max(self._iter_lengths(fieldname))

    def has_vector(self, docnum, fieldname):
        """Return ``True`` if document ``docnum`` has a ``VECTOR`` line for
        ``fieldname``.
        """
        if not self._find_doc(docnum):
            return False
        # Filter on the field name: a document may hold vectors for several
        # fields, and any one of them must not count for all.
        return self._find_line(2, "VECTOR", fn=fieldname) is not None

    def vector(self, docnum, fieldname, format_):
        """Return a ``ListMatcher`` over the vector of ``fieldname`` in
        document ``docnum``.

        :raises Exception: if the document, or its vector for this field,
            does not exist.
        """
        if not self._find_doc(docnum):
            raise Exception("No document %r" % (docnum,))
        # Filter on the field name so the postings that follow belong to the
        # requested field rather than to the document's first vector.
        if self._find_line(2, "VECTOR", fn=fieldname) is None:
            raise Exception("No vector for field %r in document %r"
                            % (fieldname, docnum))

        ids = []
        weights = []
        values = []
        c = self._find_line(3, "VPOST")
        while c is not None:
            ids.append(c["t"])
            weights.append(c["w"])
            values.append(c["v"])
            c = self._find_line(3, "VPOST")

        return ListMatcher(ids, weights, values, format_,)

    def _read_stored_fields(self):
        """Read the ``DOCFIELD`` lines of the current document into a dict.

        The file must already be positioned just after a ``DOC`` line.
        Values other than ``None`` are deserialized with ``loads``.
        """
        sfs = {}
        c = self._find_line(2, "DOCFIELD")
        while c is not None:
            v = c.get("v")
            if v is not None:
                v = loads(v)
            sfs[c["fn"]] = v
            c = self._find_line(2, "DOCFIELD")
        return sfs

    def stored_fields(self, docnum):
        """Return the stored-fields dict of document ``docnum``.

        :raises Exception: if the document does not exist.
        """
        if not self._find_doc(docnum):
            raise Exception("No document %r" % (docnum,))
        return self._read_stored_fields()

    def iter_docs(self):
        """Yield ``(docnum, stored_fields)`` pairs in file order.

        The document number comes from each ``DOC`` line rather than from a
        running counter, so the pairing holds even if the stored numbers are
        not exactly 0, 1, 2, ...
        """
        for docnum in self._iter_docs():
            yield docnum, self._read_stored_fields()

    def all_stored_fields(self):
        """Yield the stored-fields dict of each document in file order."""
        for _ in self._iter_docs():
            yield self._read_stored_fields()

    def close(self):
        """Close the underlying file and mark this reader as closed."""
        self._dbfile.close()
        self.is_closed = True
309
310
class PlainFieldWriter(base.FieldWriter, LineWriter):
    """Writes term records to the segment's ``.trm`` file.

    The file starts with a ``TERMS`` root line, followed by ``TERMFIELD``,
    ``BTEXT``, ``POST`` and ``TERMINFO`` lines at increasing indent levels.
    """

    def __init__(self, storage, segment):
        """Create the ``.trm`` file in ``storage`` and write the root line."""
        filename = segment.make_filename(".trm")
        self._dbfile = storage.create_file(filename)
        self._print_line(0, "TERMS")

    @property
    def is_closed(self):
        """The ``is_closed`` attribute of the underlying file."""
        return self._dbfile.is_closed

    def start_field(self, fieldname, fieldobj):
        """Remember ``fieldobj`` and write the ``TERMFIELD`` line."""
        self._fieldobj = fieldobj
        self._print_line(1, "TERMFIELD", fn=fieldname)

    def start_term(self, btext):
        """Start a fresh ``TermInfo`` and write the ``BTEXT`` line."""
        self._terminfo = TermInfo()
        self._print_line(2, "BTEXT", t=btext)

    def add(self, docnum, weight, vbytes, length):
        """Record one posting in the current ``TermInfo`` and write its
        ``POST`` line. ``length`` goes only to the ``TermInfo``.
        """
        self._terminfo.add_posting(docnum, weight, length)
        self._print_line(3, "POST", dn=docnum, w=weight, v=vbytes)

    def finish_term(self):
        """Write the ``TERMINFO`` line with the statistics gathered by the
        current ``TermInfo``.
        """
        info = self._terminfo
        self._print_line(
            3, "TERMINFO",
            df=info.doc_frequency(), weight=info.weight(),
            minlength=info.min_length(), maxlength=info.max_length(),
            maxweight=info.max_weight(),
            minid=info.min_id(), maxid=info.max_id(),
        )

    def add_spell_word(self, fieldname, text):
        """Write a ``SPELL`` line for ``text`` in ``fieldname``."""
        self._print_line(2, "SPELL", fn=fieldname, t=text)

    def close(self):
        """Close the underlying file."""
        self._dbfile.close()
345
346
class PlainTermsReader(base.TermsReader, LineReader):
    """Reads term records from the segment's ``.trm`` text file.

    Every lookup is a sequential scan of the file from the ``TERMS`` root
    line.
    """

    def __init__(self, storage, segment):
        """Open the segment's ``.trm`` file from ``storage``."""
        self._dbfile = storage.open_file(segment.make_filename(".trm"))
        self._segment = segment
        self.is_closed = False

    def _find_field(self, fieldname):
        """Position the file just after the ``TERMFIELD`` line of
        ``fieldname``.

        :raises TermNotFound: if the field has no such line.
        """
        self._find_root("TERMS")
        if self._find_line(1, "TERMFIELD", fn=fieldname) is None:
            raise TermNotFound("No field %r" % fieldname)

    def _iter_fields(self):
        """Yield each field name, leaving the file just after its
        ``TERMFIELD`` line.
        """
        # _find_root() requires the root command name; the terms file is
        # rooted at "TERMS" (the same root _find_field() looks for).
        self._find_root("TERMS")
        c = self._find_line(1, "TERMFIELD")
        while c is not None:
            yield c["fn"]
            c = self._find_line(1, "TERMFIELD")

    def _iter_btexts(self):
        """Yield the ``t`` value of each ``BTEXT`` line found from the
        current position, leaving the file just after that line.
        """
        c = self._find_line(2, "BTEXT")
        while c is not None:
            yield c["t"]
            c = self._find_line(2, "BTEXT")

    def _find_term(self, fieldname, btext):
        """Position the file just after the ``BTEXT`` line of the term.

        :returns: ``True`` if the term was found. ``False`` if the field's
            terms run out or a term greater than ``btext`` is met; the scan
            gives up at the first greater term.
        :raises TermNotFound: if the field itself does not exist.
        """
        self._find_field(fieldname)
        for t in self._iter_btexts():
            if t == btext:
                return True
            elif t > btext:
                break
        return False

    def _find_terminfo(self):
        """Build a ``TermInfo`` from the next ``TERMINFO`` line.

        The file must be positioned inside a term's records.

        :raises Exception: if no ``TERMINFO`` line is found.
        """
        c = self._find_line(3, "TERMINFO")
        if c is None:
            # Without this guard TermInfo(**None) fails with an opaque
            # TypeError.
            raise Exception("No TERMINFO line for the current term")
        return TermInfo(**c)

    def __contains__(self, term):
        """Return whether the ``(fieldname, btext)`` pair is in the file.

        :raises TermNotFound: if the field itself does not exist.
        """
        fieldname, btext = term
        return self._find_term(fieldname, btext)

    def indexed_field_names(self):
        """Return a generator of the field names in the file."""
        return self._iter_fields()

    def terms(self):
        """Yield ``(fieldname, btext)`` for every term in the file."""
        for fieldname in self._iter_fields():
            for btext in self._iter_btexts():
                yield (fieldname, btext)

    def terms_from(self, fieldname, prefix):
        """Yield ``(fieldname, btext)`` for each term of ``fieldname`` that
        is not less than ``prefix``.

        :raises TermNotFound: if the field does not exist.
        """
        self._find_field(fieldname)
        for btext in self._iter_btexts():
            if btext < prefix:
                continue
            yield (fieldname, btext)

    def items(self):
        """Yield ``((fieldname, btext), TermInfo)`` for every term."""
        for fieldname, btext in self.terms():
            yield (fieldname, btext), self._find_terminfo()

    def items_from(self, fieldname, prefix):
        """Yield ``((fieldname, btext), TermInfo)`` for each term produced
        by :meth:`terms_from`.
        """
        for fieldname, btext in self.terms_from(fieldname, prefix):
            yield (fieldname, btext), self._find_terminfo()

    def term_info(self, fieldname, btext):
        """Return the ``TermInfo`` of the given term.

        :raises TermNotFound: if the field or the term does not exist.
        """
        if not self._find_term(fieldname, btext):
            raise TermNotFound((fieldname, btext))
        return self._find_terminfo()

    def matcher(self, fieldname, btext, format_, scorer=None):
        """Return a ``ListMatcher`` over the ``POST`` lines of the term.

        :raises TermNotFound: if the field or the term does not exist.
        """
        if not self._find_term(fieldname, btext):
            raise TermNotFound((fieldname, btext))

        ids = []
        weights = []
        values = []
        c = self._find_line(3, "POST")
        while c is not None:
            ids.append(c["dn"])
            weights.append(c["w"])
            values.append(c["v"])
            c = self._find_line(3, "POST")

        return ListMatcher(ids, weights, values, format_, scorer=scorer)

    def close(self):
        """Close the underlying file and mark this reader as closed."""
        self._dbfile.close()
        self.is_closed = True
435
436
class PlainSegment(base.Segment):
    """Segment used with :class:`PlainTextCodec`.

    Beyond what ``base.Segment`` provides, it only keeps a document count
    that the caller sets explicitly.
    """

    def __init__(self, indexname):
        """Initialize the base segment and start with a count of zero."""
        base.Segment.__init__(self, indexname)
        self._doccount = 0

    def codec(self):
        """Return a new :class:`PlainTextCodec` instance."""
        plain_codec = PlainTextCodec()
        return plain_codec

    def set_doc_count(self, doccount):
        """Store ``doccount`` as this segment's document count."""
        self._doccount = doccount

    def doc_count(self):
        """Return the document count last stored (0 initially)."""
        count = self._doccount
        return count

    def should_assemble(self):
        """Always ``False``."""
        return False
453