1######################################################################
2# Copyright (c) 2001-2013
3#   John Holland <john@zoner.org>
4# All rights reserved.
5#
6# This software is licensed as described in the file LICENSE.txt, which
7# you should have received as part of this distribution.
8#
9######################################################################
10
11"""
Walk a tree of x12_map nodes to find the map node matching an X12 segment.

If the segment begins a loop, returns that loop's first segment node.
If the segment matches a segment within a loop, returns that segment node.
16"""
17
18import logging
19
20# Intrapackage imports
21from .errors import EngineError
22import pyx12.segment
23from .nodeCounter import NodeCounter
24
# Module logger for pyx12.walk_tree.  Uncomment one of the setLevel lines
# below to change this module's verbosity while debugging.
logger = logging.getLogger('pyx12.walk_tree')
#logger.setLevel(logging.DEBUG)
#logger.setLevel(logging.ERROR)
28
29
def pop_to_parent_loop(node):
    """
    Find the closest enclosing loop (or the map root) of a node.

    If C{node} is itself the map root it is returned unchanged.  Otherwise the
    search starts at C{node.parent} and walks upward, so a loop node yields
    its parent loop, not itself.

    @param node: Loop Node
    @type node: L{node<map_if.x12_node>}
    @return: Closest parent loop node, or the map root
    @rtype: L{node<map_if.x12_node>}
    @raise EngineError: If the node has no parent, or the parent chain ends
        before a loop or the map root is reached
    """
    if node.is_map_root():
        return node
    map_node = node.parent
    if map_node is None:
        raise EngineError("Node is None: %s" % (node.name))
    while not (map_node.is_loop() or map_node.is_map_root()):
        map_node = map_node.parent
        # Guard inside the loop: a dangling parent chain used to surface as an
        # AttributeError on None instead of the intended EngineError.
        if map_node is None:
            raise EngineError("Called pop_to_parent_loop, can't find parent loop")
    return map_node
47
48
def is_first_seg_match2(child, seg_data):
    """
    Report whether C{child} is a segment node that matches C{seg_data}.

    Used to check the first node of a loop against an incoming segment;
    a non-segment child (e.g. a loop) never matches.

    @param child: child node
    @type child: L{node<map_if.x12_node>}
    @param seg_data: Segment object
    @type seg_data: L{segment<segment.Segment>}
    @return: True only when child is a segment and child.is_match(seg_data)
        is truthy
    @rtype: boolean
    """
    if not child.is_segment():
        return False
    # Normalise whatever is_match returns into a strict boolean
    return bool(child.is_match(seg_data))
66
67
def get_id_list(node_list):
    """
    Collect the C{id} of every non-None node, preserving order.

    @param node_list: Iterable of nodes, possibly containing None entries
    @return: List of node ids
    @rtype: list
    """
    return [item.id for item in node_list if item is not None]
75
76
def traverse_path(start_node, pop_loops, push_loops):
    """
    Debug function - From the start path, pop up then push down to get a path string

    Starting from the path of the loop enclosing C{start_node}, one trailing
    path component is stripped per popped loop (asserting it equals that
    loop's id), then the id of each pushed loop is appended.

    @param start_node: Node the walk started from
    @param pop_loops: Loops left, innermost first; None entries are skipped
    @param push_loops: Loops entered, in push order; None entries are skipped
    @return: The resulting path, always starting with '/'
    @rtype: string
    """
    origin = pop_to_parent_loop(start_node).get_path()
    components = [piece for piece in origin.split('/') if piece]
    for popped_id in get_id_list(pop_loops):
        assert popped_id == components[-1], 'Path %s does not contain %s' % (origin, popped_id)
        components = components[:-1]
    components.extend(get_id_list(push_loops))
    return '/' + '/'.join(components)
89
90
class walk_tree(object):
    """
    Walks a map_if tree.  Tracks loop/segment counting, missing loop/segment.

    Occurrence counts are kept in a L{NodeCounter} keyed by each node's
    C{x12path}.  Errors for skipped mandatory segments/loops are buffered in
    C{mandatory_segs_missing} and only reported (by _flush_mandatory_segs)
    once a later match shows they really were skipped.
    """

    def __init__(self, initialCounts=None):
        """
        @param initialCounts: Initial counter state handed to NodeCounter;
            None means start with no counts
        @type initialCounts: dict or None
        """
        # Store errors until we know we have an error
        self.mandatory_segs_missing = []
        if initialCounts is None:
            initialCounts = {}
        self.counter = NodeCounter(initialCounts)

    def walk(self, node, seg_data, errh, seg_count, cur_line, ls_id):
        """
        Walk the node tree from the starting node to the node matching
        seg_data. Catch any counting or requirement errors along the way.

        Handle required segment/loop missed (not found in seg)
        Handle found segment = Not used

        @param node: Starting node
        @type node: L{node<map_if.x12_node>}
        @param seg_data: Segment object
        @type seg_data: L{segment<segment.Segment>}
        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @param seg_count: Count of current segment in the ST Loop
        @type seg_count: int
        @param cur_line: Current line number in the file
        @type cur_line: int
        @param ls_id: The current LS loop identifier
        @type ls_id: string
        @return: The matching x12 segment node, a list of x12 popped loops, and a list
            of x12 pushed loops from the start segment to the found segment.
            (None, [], []) if no node matches.
        @rtype: (L{node<map_if.segment_if>}, [L{node<map_if.loop_if>}], [L{node<map_if.loop_if>}])

        @todo: check single segment loop repeat
        """
        pop_node_list = []
        push_node_list = []
        orig_node = node
        # Pending "missing" errors never carry over from a previous walk
        self.mandatory_segs_missing = []
        node_pos = node.pos  # Get original position ordinal of starting node
        if not (node.is_loop() or node.is_map_root()):
            node = pop_to_parent_loop(node)  # Get enclosing loop
        while True:
            # Iterate through nodes with position >= current position
            for ord1 in [a for a in sorted(node.pos_map) if a >= node_pos]:
                for child in node.pos_map[ord1]:
                    if child.is_segment():
                        if child.is_match(seg_data):
                            # Is the matched segment the beginning of a loop?
                            if node.is_loop() \
                                    and self._is_loop_match(node, seg_data, errh, seg_count, cur_line, ls_id):
                                (node1, push_node_list) = self._goto_seg_match(
                                    node, seg_data, errh, seg_count, cur_line, ls_id)
                                if orig_node.is_loop() or orig_node.is_map_root():
                                    orig_loop = orig_node
                                else:
                                    orig_loop = pop_to_parent_loop(orig_node)  # Get enclosing loop
                                # The loop we started in is repeating: report it
                                # as leaving and re-entering that same loop.
                                if node == orig_loop:
                                    pop_node_list = [node]
                                    push_node_list = [node]
                                return (node1, pop_node_list, push_node_list)  # segment node
                            self.counter.increment(child.x12path)
                            self._check_seg_usage(child, seg_data, seg_count, cur_line, ls_id, errh)
                            # Remove any previously missing errors for this segment
                            self.mandatory_segs_missing = [x for x in self.mandatory_segs_missing if x[0] != child]
                            self._flush_mandatory_segs(errh, child.pos)
                            return (child, pop_node_list, push_node_list)  # segment node
                        elif child.usage == 'R' and self.counter.get_count(child.x12path) < 1:
                            # Required segment not seen yet: buffer the error; it is
                            # only reported if a later node matches.
                            fake_seg = pyx12.segment.Segment('%s' % (child.id), '~', '*', ':')
                            err_str = 'Mandatory segment "%s" (%s) missing' % (child.name, child.id)
                            self.mandatory_segs_missing.append((child, fake_seg, '3', err_str, seg_count, cur_line, ls_id))
                    elif child.is_loop():
                        if self._is_loop_match(child, seg_data, errh, seg_count, cur_line, ls_id):
                            (node_seg, push_node_list) = self._goto_seg_match(child, seg_data, errh, seg_count, cur_line, ls_id)
                            return (node_seg, pop_node_list, push_node_list)  # segment node
            # End for ord1 in pos_keys
            if node.is_map_root():  # If at root and we haven't found the segment yet.
                walk_tree._seg_not_found_error(orig_node, seg_data,
                                               errh, seg_count, cur_line, ls_id)
                return (None, [], [])
            # Not found in this loop: continue in the parent, after this loop's position
            node_pos = node.pos  # Get position ordinal of current node in tree
            pop_node_list.append(node)
            node = pop_to_parent_loop(node)  # Get enclosing parent loop

    def getCountState(self):
        """
        @return: The current counter state, as returned by NodeCounter.getState()
        """
        return self.counter.getState()

    def setCountState(self, initialCounts=None):
        """
        Replace the counter with a new NodeCounter.

        @param initialCounts: Initial counter state; None means no counts
        @type initialCounts: dict or None
        """
        # A fresh dict per call avoids the shared mutable-default pitfall
        if initialCounts is None:
            initialCounts = {}
        self.counter = NodeCounter(initialCounts)

    def forceWalkCounterToLoopStart(self, x12_path, child_path):
        """
        Set the counts as though the loop at C{x12_path} had just been entered
        through its starting segment C{child_path}.

        @param x12_path: Path of the loop node
        @param child_path: Path of the loop's starting segment
        """
        # delete child counts under the x12_path, no longer needed
        self.counter.reset_to_node(x12_path)
        self.counter.increment(x12_path)  # add a count for this path
        self.counter.increment(child_path)  # count the loop start segment

    def _check_seg_usage(self, seg_node, seg_data, seg_count, cur_line, ls_id, errh):
        """
        Check segment usage requirement and count

        @param seg_node: Segment X12 node to verify
        @type seg_node: L{node<map_if.segment_if>}
        @param seg_data: Segment object
        @type seg_data: L{segment<segment.Segment>}
        @param seg_count: Count of current segment in the ST Loop
        @type seg_count: int
        @param cur_line: Current line number in the file
        @type cur_line: int
        @param ls_id: The current LS loop identifier
        @type ls_id: string
        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @raise EngineError: On invalid usage code
        """
        # Explicit raise (not assert) so the documented check survives python -O
        if seg_node.usage not in ('N', 'R', 'S'):
            raise EngineError('Segment usage must be R, S, or N')
        if seg_node.usage == 'N':
            err_str = "Segment %s found but marked as not used" % (seg_node.id)
            errh.seg_error('2', err_str, None)
        elif seg_node.usage in ('R', 'S'):
            found = self.counter.get_count(seg_node.x12path)
            max_repeat = seg_node.get_max_repeat()
            if found > max_repeat:  # handle seg repeat count
                err_str = "Segment %s exceeded max count.  Found %i, should have %i" \
                    % (seg_data.get_seg_id(), found, max_repeat)
                errh.add_seg(seg_node, seg_data, seg_count, cur_line, ls_id)
                errh.seg_error('5', err_str, None)

    @staticmethod
    def _seg_not_found_error(orig_node, seg_data, errh, seg_count, cur_line, ls_id):
        """
        Create error for not found segments

        @param orig_node: Original starting node
        @type orig_node: L{node<map_if.x12_node>}
        @param seg_data: Segment object
        @type seg_data: L{segment<segment.Segment>}
        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @param seg_count: Count of current segment in the ST Loop
        @type seg_count: int
        @param cur_line: Current line number in the file
        @type cur_line: int
        @param ls_id: The current LS loop identifier
        @type ls_id: string
        """
        # HL segments are reported in full; others by ID and first element
        if seg_data.get_seg_id() == 'HL':
            seg_str = seg_data.format('', '*', ':')
        else:
            seg_str = '%s*%s' % (seg_data.get_seg_id(), seg_data.get_value('01'))
        err_str = 'Segment %s not found.  Started at %s' % (seg_str, orig_node.get_path())
        errh.add_seg(orig_node, seg_data, seg_count, cur_line, ls_id)
        errh.seg_error('1', err_str, None)

    def _flush_mandatory_segs(self, errh, cur_pos=None):
        """
        Handle error reporting for any outstanding missing mandatory segments

        Entries whose node is at position C{cur_pos} are neither reported nor
        discarded; they stay pending.  All other entries are reported and removed.

        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @param cur_pos: Position ordinal of the node just matched, or None
        """
        for (seg_node, seg_data, err_cde, err_str, seg_count, cur_line, ls_id) in self.mandatory_segs_missing:
            # Create errors if not also at current position
            if seg_node.pos != cur_pos:
                errh.add_seg(seg_node, seg_data, seg_count, cur_line, ls_id)
                errh.seg_error(err_cde, err_str, None)
        self.mandatory_segs_missing = [x for x in self.mandatory_segs_missing if x[0].pos == cur_pos]

    def _is_loop_match(self, loop_node, seg_data, errh, seg_count, cur_line, ls_id):
        """
        Try to match the current loop to the segment
        Handle loop and segment counting.
        Check for used/missing

        If the loop's first child is itself a loop, any child loop matching
        counts as a match.  A required loop that does not match and has not
        been seen yet is buffered as a missing-loop error.

        @param loop_node: Loop Node
        @type loop_node: L{node<map_if.loop_if>}
        @param seg_data: Segment object
        @type seg_data: L{segment<segment.Segment>}
        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @param seg_count: Count of current segment in the ST Loop
        @type seg_count: int
        @param cur_line: Current line number in the file
        @type cur_line: int
        @param ls_id: The current LS loop identifier
        @type ls_id: string

        @return: Does the segment match the first segment node in the loop?
        @rtype: boolean
        """
        assert loop_node.is_loop(), "Call to first_seg_match failed, node %s is not a loop. seg %s" \
            % (loop_node.id, seg_data.get_seg_id())
        if len(loop_node) <= 0:  # Has no children
            return False
        first_child_node = loop_node.get_first_node()
        assert first_child_node is not None, 'get_first_node failed from loop %s' % (loop_node.id)
        if first_child_node.is_loop():
            #If any loop node matches
            for child_node in loop_node.childIterator():
                if child_node.is_loop() and self._is_loop_match(child_node,
                                                                seg_data, errh, seg_count, cur_line, ls_id):
                    return True
        elif is_first_seg_match2(first_child_node, seg_data):
            return True
        elif loop_node.usage == 'R' and self.counter.get_count(loop_node.x12path) < 1:
            fake_seg = pyx12.segment.Segment('%s' % (first_child_node.id), '~', '*', ':')
            err_str = 'Mandatory loop "%s" (%s) missing' % \
                (loop_node.name, loop_node.id)
            self.mandatory_segs_missing.append((first_child_node, fake_seg,
                                                '3', err_str, seg_count, cur_line, ls_id))
        return False

    def _goto_seg_match(self, loop_node, seg_data, errh, seg_count, cur_line, ls_id):
        """
        A child loop has matched the segment.  Return that segment node.
        Handle loop counting and requirement errors.

        @param loop_node: The starting loop node.
        @type loop_node: L{node<map_if.loop_if>}
        @param seg_data: Segment object
        @type seg_data: L{segment<segment.Segment>}
        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @param seg_count: Current segment count for ST loop
        @type seg_count: int
        @param cur_line: File line counter
        @type cur_line: int
        @param ls_id: The current LS loop identifier
        @type ls_id: string

        @return: The matching segment node and a list of the push loop nodes
            (outermost first), or (None, []) if nothing matches
        @rtype: (L{node<map_if.segment_if>}, [L{node<map_if.loop_if>}])
        """
        assert loop_node.is_loop(), "_goto_seg_match failed, node %s is not a loop. seg %s" \
            % (loop_node.id, seg_data.get_seg_id())
        first_child_node = loop_node.get_first_seg()
        if first_child_node is not None and is_first_seg_match2(first_child_node, seg_data):
            self._check_loop_usage(loop_node, seg_data,
                                   seg_count, cur_line, ls_id, errh)
            self.counter.increment(first_child_node.x12path)
            self._flush_mandatory_segs(errh)
            return (first_child_node, [loop_node])
        else:
            # Descend into child loops until one starts with the segment
            for child in loop_node.childIterator():
                if child.is_loop():
                    (node1, push1) = self._goto_seg_match(child, seg_data, errh,
                                                          seg_count, cur_line, ls_id)
                    if node1:
                        push_node_list = [loop_node]
                        push_node_list.extend(push1)
                        return (node1, push_node_list)
        return (None, [])

    def _check_loop_usage(self, loop_node, seg_data, seg_count, cur_line, ls_id, errh):
        """
        Check loop usage requirement and count

        @param loop_node: Loop X12 node to verify
        @type loop_node: L{node<map_if.loop_if>}
        @param seg_data: Segment object
        @type seg_data: L{segment<segment.Segment>}
        @param seg_count: Count of current segment in the ST Loop
        @type seg_count: int
        @param cur_line: Current line number in the file
        @type cur_line: int
        @param ls_id: The current LS loop identifier
        @type ls_id: string
        @param errh: Error handler
        @type errh: L{error_handler.err_handler}
        @raise EngineError: On invalid usage code
        """
        assert loop_node.is_loop(), "Node %s is not a loop. seg %s" % (
            loop_node.id, seg_data.get_seg_id())
        # Explicit raise (not assert) so the documented check survives python -O
        if loop_node.usage not in ('N', 'R', 'S'):
            raise EngineError('Loop usage must be R, S, or N')
        if loop_node.usage == 'N':
            err_str = "Loop %s found but marked as not used" % (loop_node.id)
            errh.seg_error('2', err_str, None)
        elif loop_node.usage in ('R', 'S'):
            # New loop iteration: drop counts under the loop, then count the loop
            self.counter.reset_to_node(loop_node.x12path)
            self.counter.increment(loop_node.x12path)
            found = self.counter.get_count(loop_node.x12path)
            max_repeat = loop_node.get_max_repeat()
            if found > max_repeat:
                err_str = "Loop %s exceeded max count.  Found %i, should have %i" \
                    % (loop_node.id, found, max_repeat)
                errh.add_seg(loop_node, seg_data, seg_count, cur_line, ls_id)
                errh.seg_error('4', err_str, None)
395