1#cython: embedsignature=True
2#cython: language_level=3
3#cython: binding=True
4
5"""A fast C implementation of the Acora search engine.
6
There are two main classes, UnicodeAcora and BytesAcora, that handle
unicode data and byte data respectively.
9"""
10
11__all__ = ['BytesAcora', 'UnicodeAcora']
12
13cimport cython
14cimport cpython.exc
15cimport cpython.mem
16cimport cpython.bytes
17from cpython.ref cimport PyObject
18from cpython.unicode cimport PyUnicode_AS_UNICODE, PyUnicode_GET_SIZE
19
20from ._acora cimport (
21    _Machine, _MachineState, build_MachineState, _find_child,
22    _convert_old_format, merge_targets, _make_printable)
23
24cdef extern from * nogil:
25    ssize_t read(int fd, void *buf, size_t count)
26
27cdef extern from "acora_defs.h":
28    # PEP 393
29    cdef bint PyUnicode_IS_READY(object u)
30    cdef Py_ssize_t PyUnicode_GET_LENGTH(object u)
31    cdef int PyUnicode_KIND(object u)
32    cdef void* PyUnicode_DATA(object u)
33    cdef int PyUnicode_WCHAR_KIND
34    cdef Py_UCS4 PyUnicode_READ(int kind, void* data, Py_ssize_t index) nogil
35
36
37DEF FILE_BUFFER_SIZE = 32 * 1024
38
39ctypedef struct _AcoraUnicodeNodeStruct:
40    Py_UCS4* characters
41    _AcoraUnicodeNodeStruct** targets
42    PyObject** matches
43    int char_count
44
45ctypedef struct _AcoraBytesNodeStruct:
46    unsigned char* characters
47    _AcoraBytesNodeStruct** targets
48    PyObject** matches
49    int char_count
50
51
52# state machine building support
53
54def insert_bytes_keyword(_MachineState tree, keyword, long state_id, bint ignore_case=False):
55    # keep in sync with insert_unicode_keyword()
56    cdef _MachineState child
57    cdef unsigned char ch
58    if not isinstance(keyword, bytes):
59        raise TypeError("expected bytes object, got %s" % type(keyword).__name__)
60    if not <bytes>keyword:
61        raise ValueError("cannot search for the empty string")
62    #print(keyword)
63    for ch in <bytes>keyword:
64        if ignore_case:
65            if c'A' <= ch <= c'Z':
66                ch += c'a' - c'A'
67        #print(ch)
68        if tree.children is None:
69            tree.children = []
70            child = None
71        else:
72            child = _find_child(tree, ch)
73        if child is None:
74            child = build_MachineState(state_id)
75            child.letter = ch
76            state_id += 1
77            tree.children.append(child)
78        tree = child
79    if ignore_case and tree.matches:
80        if keyword not in tree.matches:
81            tree.matches.append(keyword)
82    else:
83        tree.matches = [keyword]
84    return state_id
85
86
87def insert_unicode_keyword(_MachineState tree, keyword, long state_id, bint ignore_case=False):
88    # keep in sync with insert_bytes_keyword()
89    cdef _MachineState child
90    cdef Py_UCS4 ch
91    if not isinstance(keyword, unicode):
92        raise TypeError("expected Unicode string, got %s" % type(keyword).__name__)
93    if not <unicode>keyword:
94        raise ValueError("cannot search for the empty string")
95    for ch in <unicode>keyword:
96        if ignore_case:
97            ch = ch.lower()
98        if tree.children is None:
99            tree.children = []
100            child = None
101        else:
102            child = _find_child(tree, ch)
103        if child is None:
104            child = build_MachineState(state_id)
105            child.letter = ch
106            state_id += 1
107            tree.children.append(child)
108        tree = child
109    if ignore_case and tree.matches:
110        if keyword not in tree.matches:
111            tree.matches.append(keyword)
112    else:
113        tree.matches = [keyword]
114    return state_id
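
# Rough sketch of how the two insert functions above are driven (the real
# driver is the AcoraBuilder in the acora package; everything outside this
# module is assumed here, and insert_bytes_keyword() is used the same way for
# bytes keywords):
#
#     tree = build_MachineState(0)      # root state of the keyword trie
#     state_id = 1
#     for keyword in (u'ab', u'bc'):
#         state_id = insert_unicode_keyword(tree, keyword, state_id)
#     engine = UnicodeAcora(_Machine(tree, ignore_case=False))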
115
116
117def machine_to_dot(machine, out=None):
118    cdef _AcoraUnicodeNodeStruct* unode
119    cdef _AcoraUnicodeNodeStruct* unodes = NULL
120    cdef _AcoraBytesNodeStruct* bnode
121    cdef _AcoraBytesNodeStruct* bnodes = NULL
122    cdef Py_ssize_t node_count, node_id
123    cdef PyObject **cmatches
124    cdef Py_UCS4 ch
125    cdef unsigned char bch
126
127    if isinstance(machine, UnicodeAcora):
128        unodes = (<UnicodeAcora>machine).start_node
129        node_count = (<UnicodeAcora>machine).node_count
130    elif isinstance(machine, BytesAcora):
131        bnodes = (<BytesAcora>machine).start_node
132        node_count = (<BytesAcora>machine).node_count
133    else:
134        raise TypeError(
135            "Expected UnicodeAcora or BytesAcora instance, got %s" % machine.__class__.__name__)
136
137    if out is None:
138        from sys import stdout as out
139
140    write = out.write
141    write("digraph {\n")
142    write('%s [label="%s"];\n' % (0, 'start'))
143    seen = set()
144    for node_id in range(node_count):
145        if unodes:
146            unode = unodes + node_id
147            characters = [ch for ch in unode.characters[:unode.char_count]]
148            child_ids = [<size_t>(child - unodes) for child in unode.targets[:unode.char_count]]
149            cmatches = unode.matches
150        else:
151            bnode = bnodes + node_id
152            characters = [<bytes>bch for bch in bnode.characters[:bnode.char_count]]
153            child_ids = [<size_t>(child - bnodes) for child in bnode.targets[:bnode.char_count]]
154            cmatches = bnode.matches
155
156        if cmatches is not NULL:
157            matches = []
158            while cmatches[0]:
159                matches.append(_make_printable(<object>cmatches[0]))
160                cmatches += 1
161            if matches:
                write('M%s [label="%s", shape=note];\n' % (
                    node_id, '\\n'.join(matches)))
164                write('%s -> M%s [style=dotted];\n' % (node_id, node_id))
165
166        for child_id, character in zip(child_ids, characters):
167            character = _make_printable(character)
168            if child_id not in seen:
169                write('%s [label="%s"];\n' % (child_id, character))
170                seen.add(child_id)
171            write('%s -> %s [label="%s"];\n' % (node_id, child_id, character))
172    write("}\n")
173
174
175# Unicode machine
176
177cdef int _init_unicode_node(
178        _AcoraUnicodeNodeStruct* c_node, _MachineState state,
179        _AcoraUnicodeNodeStruct* all_nodes,
180        dict node_offsets, dict pyrefs, bint ignore_case) except -1:
181    cdef _MachineState child, fail_state
182    cdef size_t mem_size
183    cdef Py_ssize_t i
184    cdef unicode letter
185    cdef dict targets
186
187    # merge children failure states and matches to avoid deep failure state traversal
188    targets, matches = merge_targets(state, ignore_case)
189    cdef size_t child_count = len(targets)
190
191    # use a single malloc for targets and match-string pointers
192    mem_size = sizeof(_AcoraUnicodeNodeStruct**) * child_count
193    if matches:
194        mem_size += sizeof(PyObject*) * (len(matches) + 1)  # NULL terminated
195    mem_size += sizeof(Py_UCS4) * child_count
196    c_node.targets = <_AcoraUnicodeNodeStruct**> cpython.mem.PyMem_Malloc(mem_size)
197    if c_node.targets is NULL:
198        raise MemoryError()
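    # Layout of the single allocation (the targets pointer is the block start):
    #   [targets    : child_count node pointers]
    #   [matches    : len(matches) + 1 PyObject* slots, NULL terminated (only when there are matches)]
    #   [characters : child_count Py_UCS4 values]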
199
200    if not matches:
201        c_node.matches = NULL
202        c_characters = <Py_UCS4*> (c_node.targets + child_count)
203    else:
204        c_node.matches = <PyObject**> (c_node.targets + child_count)
205        matches = _intern(pyrefs, tuple(matches))
206        i = 0
207        for match in matches:
208            c_node.matches[i] = <PyObject*>match
209            i += 1
210        c_node.matches[i] = NULL
211        c_characters = <Py_UCS4*> (c_node.matches + i + 1)
212
213    if state.children and len(targets) == len(state.children):
214        for i, child in enumerate(state.children):
215            c_node.targets[i] = all_nodes + <size_t>node_offsets[child]
216            c_characters[i] = child.letter
217    else:
218        # dict[key] is much faster than creating and sorting item tuples
219        for i, character in enumerate(sorted(targets)):
220            c_node.targets[i] = all_nodes + <size_t>node_offsets[targets[character]]
221            c_characters[i] = character
222
223    c_node.characters = c_characters
224    c_node.char_count = child_count
225
226
227cdef int _init_bytes_node(
228        _AcoraBytesNodeStruct* c_node, state,
229        _AcoraBytesNodeStruct* all_nodes,
230        dict node_offsets, dict pyrefs, bint ignore_case) except -1:
231    cdef _MachineState child, fail_state
232    cdef size_t mem_size
233    cdef Py_ssize_t i
234    cdef unicode letter
235    cdef dict targets
236
237    # merge children failure states and matches to avoid deep failure state traversal
238    targets, matches = merge_targets(state, ignore_case)
239    cdef size_t child_count = len(targets)
240
241    # use a single malloc for targets and match-string pointers
242    mem_size = targets_mem_size = sizeof(_AcoraBytesNodeStruct**) * len(targets)
243    if matches:
244        mem_size += sizeof(PyObject*) * (len(matches) + 1) # NULL terminated
245    c_node.targets = <_AcoraBytesNodeStruct**> cpython.mem.PyMem_Malloc(mem_size)
246    if c_node.targets is NULL:
247        raise MemoryError()
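    # Layout note: unlike the Unicode nodes, the characters of a bytes node are
    # not part of this allocation; they are kept in a separate, interned Python
    # bytes object below, so the block only holds the targets and the matches.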
248
249    if mem_size == targets_mem_size:  # no matches
250        c_node.matches = NULL
251    else:
252        c_node.matches = <PyObject**> (c_node.targets + len(targets))
253        matches = _intern(pyrefs, tuple(matches))
254        i = 0
255        for match in matches:
256            c_node.matches[i] = <PyObject*>match
257            i += 1
258        c_node.matches[i] = NULL
259
260    characters = cpython.bytes.PyBytes_FromStringAndSize(NULL, len(targets))
261    cdef unsigned char *c_characters = characters
    if state.children and len(targets) == len(state.children):
263        for i, child in enumerate(state.children):
264            c_node.targets[i] = all_nodes + <size_t>node_offsets[child]
265            c_characters[i] = child.letter
266    else:
267        # dict[key] is much faster than creating and sorting item tuples
268        for i, character in enumerate(sorted(targets)):
269            c_node.targets[i] = all_nodes + <size_t>node_offsets[targets[character]]
270            c_characters[i] = <Py_UCS4>character
271    characters = _intern(pyrefs, characters)
272
273    c_node.characters = characters
274    c_node.char_count = len(characters)
275
276
277cdef inline _intern(dict d, obj):
278    if obj in d:
279        return d[obj]
280    d[obj] = obj
281    return obj
282
283
284cdef dict group_transitions_by_state(dict transitions):
285    transitions_by_state = {}
    for (state, character), target in transitions.items():
287        if state in transitions_by_state:
288            transitions_by_state[state].append((character, target))
289        else:
290            transitions_by_state[state] = [(character, target)]
291    return transitions_by_state
292
293
294# unicode data handling
295
296cdef class UnicodeAcora:
297    """Acora search engine for unicode data.
298    """
299    cdef _AcoraUnicodeNodeStruct* start_node
300    cdef Py_ssize_t node_count
301    cdef tuple _pyrefs
302    cdef bint _ignore_case
303
304    def __cinit__(self, start_state, dict transitions=None):
305        cdef _Machine machine
306        cdef _AcoraUnicodeNodeStruct* c_nodes
307        cdef _AcoraUnicodeNodeStruct* c_node
308        cdef Py_ssize_t i
309
310        if transitions is not None:
311            # old pickle format => rebuild trie
312            machine = _convert_old_format(transitions)
313        else:
314            machine = start_state
315        ignore_case = self._ignore_case = machine.ignore_case
316        self.node_count = len(machine.child_states) + 1
317
318        c_nodes = self.start_node = <_AcoraUnicodeNodeStruct*> cpython.mem.PyMem_Malloc(
319            sizeof(_AcoraUnicodeNodeStruct) * self.node_count)
320        if c_nodes is NULL:
321            raise MemoryError()
322
323        for c_node in c_nodes[:self.node_count]:
324            # required by __dealloc__ in case of subsequent errors
325            c_node.targets = NULL
326
327        node_offsets = {state: i for i, state in enumerate(machine.child_states, 1)}
328        node_offsets[machine.start_state] = 0
329        pyrefs = {}  # used to keep Python references alive (and intern them)
330
331        _init_unicode_node(c_nodes, machine.start_state, c_nodes, node_offsets, pyrefs, ignore_case)
332        for i, state in enumerate(machine.child_states, 1):
333            _init_unicode_node(c_nodes + i, state, c_nodes, node_offsets, pyrefs, ignore_case)
334        self._pyrefs = tuple(pyrefs)
335
336    def __dealloc__(self):
337        cdef Py_ssize_t i
338        if self.start_node is not NULL:
339            for i in range(self.node_count):
340                if self.start_node[i].targets is not NULL:
341                    cpython.mem.PyMem_Free(self.start_node[i].targets)
342            cpython.mem.PyMem_Free(self.start_node)
343
344    def __reduce__(self):
345        """pickle"""
346        cdef _AcoraUnicodeNodeStruct* c_node
347        cdef _AcoraUnicodeNodeStruct* c_child
348        cdef _AcoraUnicodeNodeStruct* c_start_node = self.start_node
349        cdef Py_ssize_t state_id, i
350        cdef bint ignore_case
351        states = {}
352        states_list = []
353        for state_id in range(self.node_count):
354            state = states[state_id] = {'id': state_id}
355            states_list.append(state)
356            c_node = c_start_node + state_id
357            if c_node.matches:
358                state['m'] = matches = []
359                match = c_node.matches
360                while match[0]:
361                    matches.append(<unicode>match[0])
362                    match += 1
363
364        # create child links
365        ignore_case = self._ignore_case
366        for state_id in range(self.node_count):
367            c_node = c_start_node + state_id
368            if not c_node.char_count:
369                continue
370            state = states[state_id]
371            state['c'] = children = []
372            for i in range(c_node.char_count):
373                ch = c_node.characters[i]
374                if ignore_case and ch.isupper():
375                    # ignore upper case characters, assuming that lower case exists as well
376                    continue
377                c_child = c_node.targets[i]
378                child_id = c_child - c_start_node
379                children.append((ch, child_id))
380
381        return _unpickle, (self.__class__, states_list, self._ignore_case,)
382
383    cpdef finditer(self, unicode data):
384        """Iterate over all occurrences of any keyword in the string.
385
386        Returns (keyword, offset) pairs.
387        """
388        if self.start_node.char_count == 0:
389            return iter(())
390        return _UnicodeAcoraIter(self, data)
391
392    def findall(self, unicode data):
393        """Find all occurrences of any keyword in the string.
394
395        Returns a list of (keyword, offset) pairs.
396        """
397        return list(self.finditer(data))
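
    # Example: for an engine built from the keywords u'ab' and u'bc',
    #
    #     engine.findall(u'abcabc')  ->  [(u'ab', 0), (u'bc', 1), (u'ab', 3), (u'bc', 4)]
    #
    # finditer() yields the same pairs lazily.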
398
399
400def _unpickle(type cls not None, list states_list not None, bint ignore_case):
401    if not issubclass(cls, (UnicodeAcora, BytesAcora)):
402        raise ValueError(
403            "Invalid machine class, expected UnicodeAcora or BytesAcora, got %s" % cls.__name__)
404
405    cdef Py_ssize_t i
406    states = {i: build_MachineState(i) for i in range(len(states_list))}
407    start_state = states[0]
408    for state_data in states_list:
409        state = states[state_data['id']]
410        state.matches = state_data.get('m')
411        state.children = children = []
412        for character, child_id in state_data.get('c', ()):
413            child = states[child_id]
414            child.letter = character
415            children.append(child)
416
417    return cls(_Machine(start_state, ignore_case=ignore_case))
418
419
420cdef class _UnicodeAcoraIter:
421    cdef _AcoraUnicodeNodeStruct* current_node
422    cdef _AcoraUnicodeNodeStruct* start_node
423    cdef Py_ssize_t data_pos, data_len, match_index
424    cdef unicode data
425    cdef UnicodeAcora acora
426    cdef void* data_start
427    cdef int unicode_kind
428
429    def __cinit__(self, UnicodeAcora acora not None, unicode data not None):
430        assert acora.start_node is not NULL
431        assert acora.start_node.matches is NULL
432        self.acora = acora
433        self.start_node = self.current_node = acora.start_node
434        self.match_index = 0
435        self.data = data
436        self.data_pos = 0
437        if PyUnicode_IS_READY(data):
438            # PEP393 Unicode string
439            self.data_start = PyUnicode_DATA(data)
440            self.data_len = PyUnicode_GET_LENGTH(data)
441            self.unicode_kind = PyUnicode_KIND(data)
442        else:
443            # pre-/non-PEP393 Unicode string
444            self.data_start = PyUnicode_AS_UNICODE(data)
445            self.data_len = PyUnicode_GET_SIZE(data)
446            self.unicode_kind = PyUnicode_WCHAR_KIND
447
448        if not acora.start_node.char_count:
449            raise ValueError("Non-empty engine required")
450
451    def __iter__(self):
452        return self
453
454    def __next__(self):
455        cdef void* data_start = self.data_start
456        cdef Py_UCS4* test_chars
457        cdef Py_UCS4 current_char
458        cdef int i, found = 0, start, mid, end
459        cdef Py_ssize_t data_len = self.data_len, data_pos = self.data_pos
460        cdef _AcoraUnicodeNodeStruct* start_node = self.start_node
461        cdef _AcoraUnicodeNodeStruct* current_node = self.current_node
462
463        if current_node.matches is not NULL:
464            if current_node.matches[self.match_index] is not NULL:
465                return self._build_next_match()
466            self.match_index = 0
467
468        kind = self.unicode_kind
469        with nogil:
470            while data_pos < data_len:
471                current_char = PyUnicode_READ(kind, data_start, data_pos)
472                data_pos += 1
473                current_node = _step_to_next_node(start_node, current_node, current_char)
474                if current_node.matches is not NULL:
475                    found = 1
476                    break
477        self.data_pos = data_pos
478        self.current_node = current_node
479        if found:
480            return self._build_next_match()
481        raise StopIteration
482
483    cdef _build_next_match(self):
484        match = <unicode> self.current_node.matches[self.match_index]
485        self.match_index += 1
486        return match, self.data_pos - len(match)
487
488
489# bytes data handling
490
491cdef class BytesAcora:
492    """Acora search engine for byte data.
493    """
494    cdef _AcoraBytesNodeStruct* start_node
495    cdef Py_ssize_t node_count
496    cdef tuple _pyrefs
497    cdef bint _ignore_case
498
499    def __cinit__(self, start_state, dict transitions=None):
500        cdef _Machine machine
501        cdef _AcoraBytesNodeStruct* c_nodes
502        cdef _AcoraBytesNodeStruct* c_node
503        cdef Py_ssize_t i
504
505        if transitions is not None:
506            # old pickle format => rebuild trie
507            machine = _convert_old_format(transitions)
508        else:
509            machine = start_state
510        ignore_case = self._ignore_case = machine.ignore_case
511        self.node_count = len(machine.child_states) + 1
512
513        c_nodes = self.start_node = <_AcoraBytesNodeStruct*> cpython.mem.PyMem_Malloc(
514            sizeof(_AcoraBytesNodeStruct) * self.node_count)
515        if c_nodes is NULL:
516            raise MemoryError()
517
518        for c_node in c_nodes[:self.node_count]:
519            # required by __dealloc__ in case of subsequent errors
520            c_node.targets = NULL
521
522        node_offsets = {state: i for i, state in enumerate(machine.child_states, 1)}
523        node_offsets[machine.start_state] = 0
524        pyrefs = {}  # used to keep Python references alive (and intern them)
525
526        _init_bytes_node(c_nodes, machine.start_state, c_nodes, node_offsets, pyrefs, ignore_case)
527        for i, state in enumerate(machine.child_states, 1):
528            _init_bytes_node(c_nodes + i, state, c_nodes, node_offsets, pyrefs, ignore_case)
529        self._pyrefs = tuple(pyrefs)
530
531    def __dealloc__(self):
532        cdef Py_ssize_t i
533        if self.start_node is not NULL:
534            for i in range(self.node_count):
535                if self.start_node[i].targets is not NULL:
536                    cpython.mem.PyMem_Free(self.start_node[i].targets)
537            cpython.mem.PyMem_Free(self.start_node)
538
539    def __reduce__(self):
540        """pickle"""
541        cdef _AcoraBytesNodeStruct* c_node
542        cdef _AcoraBytesNodeStruct* c_child
543        cdef _AcoraBytesNodeStruct* c_start_node = self.start_node
544        cdef Py_ssize_t state_id, i
545        cdef bint ignore_case
546
547        states = {}
548        states_list = []
549        for state_id in range(self.node_count):
550            state = states[state_id] = {'id': state_id}
551            states_list.append(state)
552            c_node = c_start_node + state_id
553            if c_node.matches:
554                state['m'] = matches = []
555                match = c_node.matches
556                while match[0]:
                    matches.append(<bytes>match[0])
558                    match += 1
559
560        # create child links
561        ignore_case = self._ignore_case
562        for state_id in range(self.node_count):
563            c_node = c_start_node + state_id
564            if not c_node.char_count:
565                continue
566            state = states[state_id]
567            state['c'] = children = []
568            for i in range(c_node.char_count):
569                ch = c_node.characters[i]
                if ignore_case and c'A' <= ch <= c'Z':
571                    # ignore upper case characters, assuming that lower case exists as well
572                    continue
573                c_child = c_node.targets[i]
574                child_id = c_child - c_start_node
575                children.append((ch, child_id))
576
577        return _unpickle, (self.__class__, states_list, self._ignore_case,)
578
579    cpdef finditer(self, bytes data):
580        """Iterate over all occurrences of any keyword in the string.
581
582        Returns (keyword, offset) pairs.
583        """
584        if self.start_node.char_count == 0:
585            return iter(())
586        return _BytesAcoraIter(self, data)
587
588    def findall(self, bytes data):
589        """Find all occurrences of any keyword in the string.
590
591        Returns a list of (keyword, offset) pairs.
592        """
593        return list(self.finditer(data))
594
595    def filefind(self, f):
596        """Iterate over all occurrences of any keyword in a file.
597
        The file must be either a file path, a file opened in binary mode,
        or a file-like object that returns bytes objects from its .read() method.
600
601        Returns (keyword, offset) pairs.
602        """
603        if self.start_node.char_count == 0:
604            return iter(())
605        close_file = False
606        if not hasattr(f, 'read'):
607            f = open(f, 'rb')
608            close_file = True
609        return _FileAcoraIter(self, f, close_file)
610
611    def filefindall(self, f):
612        """Find all occurrences of any keyword in a file.
613
614        Returns a list of (keyword, offset) pairs.
615        """
616        return list(self.filefind(f))
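
    # Illustrative sketch (the AcoraBuilder front-end and the file name below
    # are outside this module and only assumed here): scan a binary file for
    # byte keywords without reading it into memory at once:
    #
    #     engine = AcoraBuilder(b'\xff\xd8', b'MAGIC').build()   # -> BytesAcora
    #     for keyword, offset in engine.filefind('dump.bin'):
    #         print(keyword, offset)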
617
618
619cdef class _BytesAcoraIter:
620    cdef _AcoraBytesNodeStruct* current_node
621    cdef _AcoraBytesNodeStruct* start_node
622    cdef Py_ssize_t match_index
623    cdef bytes data
624    cdef BytesAcora acora
625    cdef unsigned char* data_char
626    cdef unsigned char* data_end
627    cdef unsigned char* data_start
628
629    def __cinit__(self, BytesAcora acora not None, bytes data):
630        assert acora.start_node is not NULL
631        assert acora.start_node.matches is NULL
632        self.acora = acora
633        self.start_node = self.current_node = acora.start_node
634        self.match_index = 0
635        self.data_char = self.data_start = self.data = data
636        self.data_end = self.data_char + len(data)
637
638        if not acora.start_node.char_count:
639            raise ValueError("Non-empty engine required")
640
641    def __iter__(self):
642        return self
643
644    def __next__(self):
645        cdef unsigned char* data_char = self.data_char
646        cdef unsigned char* data_end = self.data_end
647        cdef unsigned char* test_chars
648        cdef unsigned char current_char
649        cdef int i, found = 0
650        if self.current_node.matches is not NULL:
651            if self.current_node.matches[self.match_index] is not NULL:
652                return self._build_next_match()
653            self.match_index = 0
654        with nogil:
655            found = _search_in_bytes(self.start_node, data_end,
656                                     &self.data_char, &self.current_node)
657        if found:
658            return self._build_next_match()
659        raise StopIteration
660
661    cdef _build_next_match(self):
662        match = <bytes> self.current_node.matches[self.match_index]
663        self.match_index += 1
664        return (match, <Py_ssize_t>(self.data_char - self.data_start) - len(match))
665
666
667cdef int _search_in_bytes(_AcoraBytesNodeStruct* start_node,
668                          unsigned char* data_end,
669                          unsigned char** _data_char,
670                          _AcoraBytesNodeStruct** _current_node) nogil:
671    cdef unsigned char* data_char = _data_char[0]
672    cdef _AcoraBytesNodeStruct* current_node = _current_node[0]
673    cdef unsigned char current_char
674    cdef int found = 0
675
676    while data_char < data_end:
677        current_char = data_char[0]
678        data_char += 1
679        current_node = _step_to_next_node(start_node, current_node, current_char)
680        if current_node.matches is not NULL:
681            found = 1
682            break
683    _data_char[0] = data_char
684    _current_node[0] = current_node
685    return found
686
687
688ctypedef fused _AcoraNodeStruct:
689    _AcoraBytesNodeStruct
690    _AcoraUnicodeNodeStruct
691
692ctypedef fused _inputCharType:
693    unsigned char
694    Py_UCS4
695
696
697@cython.cdivision(True)
698cdef inline _AcoraNodeStruct* _step_to_next_node(
699        _AcoraNodeStruct* start_node,
700        _AcoraNodeStruct* current_node,
701        _inputCharType current_char) nogil:
702
703    cdef _inputCharType* test_chars = <_inputCharType*>current_node.characters
704    cdef int i, start, mid, end
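
    # characters[] is sorted, so the transition lookup below works as follows:
    #   1) compare against both ends first (cheap, catches characters outside the range),
    #   2) bisect while the remaining range is large,
    #   3) finish with a linear scan over the small remainder.
    # A character without a transition always falls back to the start node.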
705
706    end = current_node.char_count
707    if current_char <= test_chars[0]:
708        return current_node.targets[0] if current_char == test_chars[0] else start_node
709
710    if current_char >= test_chars[end-1]:
711        return current_node.targets[end-1] if current_char == test_chars[end-1] else start_node
712
713    # bisect into larger character maps (> 8 seems to perform best for me)
714    start = 0
715    while end - start > 8:
716        mid = (start + end) // 2
717        if current_char < test_chars[mid]:
718            end = mid
719        elif current_char == test_chars[mid]:
720            return current_node.targets[mid]
721        else:
722            start = mid
723
724    # sequentially run through small character maps
725    for i in range(start, end):
726        if current_char <= test_chars[i]:
727            return current_node.targets[i] if current_char == test_chars[i] else start_node
728
729    return start_node
730
731
732# file data handling
733
734cdef class _FileAcoraIter:
735    cdef _AcoraBytesNodeStruct* current_node
736    cdef _AcoraBytesNodeStruct* start_node
737    cdef Py_ssize_t match_index, read_size, buffer_offset_count
738    cdef bytes buffer
739    cdef unsigned char* c_buffer_pos
740    cdef unsigned char* c_buffer_end
741    cdef object f
742    cdef bint close_file
743    cdef int c_file
744    cdef BytesAcora acora
745
746    def __cinit__(self, BytesAcora acora not None, f, bint close=False, Py_ssize_t buffer_size=FILE_BUFFER_SIZE):
747        assert acora.start_node is not NULL
748        assert acora.start_node.matches is NULL
749        self.acora = acora
750        self.start_node = self.current_node = acora.start_node
751        self.match_index = 0
752        self.buffer_offset_count = 0
753        self.f = f
754        self.close_file = close
755        try:
756            self.c_file = f.fileno() if f.tell() == 0 else -1
757        except:
758            # maybe not a C file?
759            self.c_file = -1
760        self.read_size = buffer_size
761        if self.c_file == -1:
762            self.buffer = b''
763        else:
            # use a pre-allocated, fixed-size C buffer that is reused for all reads
765            self.buffer = b'\0' * buffer_size
766        self.c_buffer_pos = self.c_buffer_end = <unsigned char*> self.buffer
767
768        if not acora.start_node.char_count:
769            raise ValueError("Non-empty engine required")
770
771    def __iter__(self):
772        return self
773
774    def __next__(self):
775        cdef bytes buffer
776        cdef unsigned char* c_buffer
777        cdef unsigned char* data_end
778        cdef int error = 0, found = 0
779        cdef Py_ssize_t buffer_size, bytes_read = 0
780        if self.c_buffer_pos is NULL:
781            raise StopIteration
782        if self.current_node.matches is not NULL:
783            if self.current_node.matches[self.match_index] is not NULL:
784                return self._build_next_match()
785            self.match_index = 0
786
787        buffer_size = len(self.buffer)
788        c_buffer = <unsigned char*> self.buffer
789        if self.c_file != -1:
790            with nogil:
791                found = _find_next_match_in_cfile(
792                    self.c_file, c_buffer, buffer_size, self.start_node,
793                    &self.c_buffer_pos, &self.c_buffer_end,
794                    &self.buffer_offset_count, &self.current_node, &error)
795            if error:
796                cpython.exc.PyErr_SetFromErrno(IOError)
797        else:
798            # Why not always release the GIL and only acquire it when reading?
799            # Well, it's actually doing that.  When the search finds something,
800            # we have to acquire the GIL in order to return the result, and if
801            # it does not find anything, then we have to acquire the GIL in order
802            # to read more data.  So, wrapping the search call in a nogil section
803            # is actually enough.
804            data_end = c_buffer + buffer_size
805            while not found:
806                if self.c_buffer_pos >= data_end:
807                    self.buffer_offset_count += buffer_size
808                    self.buffer = self.f.read(self.read_size)
809                    buffer_size = len(self.buffer)
810                    if buffer_size == 0:
811                        self.c_buffer_pos = NULL
812                        break
813                    c_buffer = self.c_buffer_pos = <unsigned char*> self.buffer
814                    data_end = c_buffer + buffer_size
815                with nogil:
816                    found = _search_in_bytes(
817                        self.start_node, data_end,
818                        &self.c_buffer_pos, &self.current_node)
819        if self.c_buffer_pos is NULL:
820            if self.close_file:
821                self.f.close()
822        elif found:
823            return self._build_next_match()
824        raise StopIteration
825
826    cdef _build_next_match(self):
827        match = <bytes> self.current_node.matches[self.match_index]
828        self.match_index += 1
829        return (match, self.buffer_offset_count + (
830                self.c_buffer_pos - (<unsigned char*> self.buffer)) - len(match))
831
832
833cdef int _find_next_match_in_cfile(int c_file, unsigned char* c_buffer, size_t buffer_size,
834                                   _AcoraBytesNodeStruct* start_node,
835                                   unsigned char** _buffer_pos, unsigned char** _buffer_end,
836                                   Py_ssize_t* _buffer_offset_count,
837                                   _AcoraBytesNodeStruct** _current_node,
838                                   int* error) nogil:
839    cdef unsigned char* buffer_pos = _buffer_pos[0]
840    cdef unsigned char* buffer_end = _buffer_end[0]
841    cdef unsigned char* data_end = c_buffer + buffer_size
842    cdef Py_ssize_t buffer_offset_count = _buffer_offset_count[0]
843    cdef _AcoraBytesNodeStruct* current_node = _current_node[0]
844    cdef int found = 0
845    cdef Py_ssize_t bytes_read
846
847    while not found:
848        if buffer_pos >= buffer_end:
849            buffer_offset_count += buffer_end - c_buffer
850            bytes_read = read(c_file, c_buffer, buffer_size)
851            if bytes_read <= 0:
852                if bytes_read < 0:
853                    error[0] = 1
854                buffer_pos = NULL
855                break
856            buffer_pos = c_buffer
857            buffer_end = c_buffer + bytes_read
858
859        found = _search_in_bytes(
860            start_node, buffer_end, &buffer_pos, &current_node)
861
862    _current_node[0] = current_node
863    _buffer_offset_count[0] = buffer_offset_count
864    _buffer_pos[0] = buffer_pos
865    _buffer_end[0] = buffer_end
866    return found
867