1# -*- coding: utf-8 -*-
2#  Copyright 2011 Takeshi KOMIYA
3#
4#  Licensed under the Apache License, Version 2.0 (the "License");
5#  you may not use this file except in compliance with the License.
6#  You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10#  Unless required by applicable law or agreed to in writing, software
11#  distributed under the License is distributed on an "AS IS" BASIS,
12#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13#  See the License for the specific language governing permissions and
14#  limitations under the License.
15
16from blockdiag import parser
17from blockdiag.elements import Diagram, DiagramEdge, DiagramNode, NodeGroup
18from blockdiag.plugins import fire_node_event
19from blockdiag.utils import XY, unquote
20from blockdiag.utils.compat import cmp_to_key
21
22
class DiagramTreeBuilder:
    """Turn a parsed blockdiag statement tree into a ``Diagram`` object tree.

    ``build`` instantiates nodes, edges, groups, attributes and extensions
    from the parser output, detaches groups that ended up without nodes,
    binds edges to the groups holding their nodes and finally fires the
    ``build_finished`` node event.
    """

    def build(self, tree, config):
        """Build and return a new ``Diagram`` from *tree*.

        :param tree: parser output; its ``stmts`` are walked recursively.
        :param config: forwarded to plugins declared in the diagram source.
        :return: the populated ``Diagram``.
        """
        self.config = config
        self.diagram = Diagram()
        self.instantiate(self.diagram, tree)
        self._prune_empty_groups(self.diagram)

        self.bind_edges(self.diagram)
        self.fire_node_event('build_finished')
        return self.diagram

    def _prune_empty_groups(self, diagram):
        """Detach every group of *diagram* that has no nodes from its parent."""
        # Iterate over a snapshot: removing a group mutates a ``nodes`` list
        # that a lazy traversal may still be walking, which would otherwise
        # skip the sibling right after each removed group.
        for subgroup in list(diagram.traverse_groups()):
            if len(subgroup.nodes) == 0:
                subgroup.group.nodes.remove(subgroup)

    def fire_node_event(self, event_type):
        """Fire *event_type* through the plugin hook for drawable nodes.

        Only the diagram's direct ``nodes`` are visited.
        NOTE(review): nodes nested inside groups are not reached by this
        loop -- confirm that is intended.
        """
        for node in self.diagram.nodes:
            if node.drawable:
                fire_node_event(node, event_type)

    def is_related_group(self, group1, group2):
        """Return True if either group is reported as parent of the other."""
        return bool(group1.is_parent(group2) or group2.is_parent(group1))

    def belong_to(self, node, group):
        """Make *group* the owner of *node* (a node or a subgroup).

        A node already placed in a deeper group (higher ``level``) than
        *group* is left where it is.  Otherwise, if it currently belongs to a
        different group, it is moved: in its old group it is replaced by
        ``group.parent(old_group.level + 1)`` (presumably the ancestor of
        *group* one level below the old group -- confirm), then it is
        appended to *group*.

        :raises RuntimeError: if the old and new groups are not related.
        """
        # Only override an existing membership from the same or a deeper
        # nesting level.
        if node.group and node.group.level > group.level:
            override = False
        else:
            override = True

        if node.group and node.group != group and override:
            if not self.is_related_group(node.group, group):
                msg = "could not belong to two groups: %s" % node.id
                raise RuntimeError(msg)

            old_group = node.group

            # Keep the old group's ordering: the intermediate group takes the
            # slot right after the node it replaces.
            parent = group.parent(old_group.level + 1)
            if parent:
                if parent in old_group.nodes:
                    old_group.nodes.remove(parent)

                index = old_group.nodes.index(node)
                old_group.nodes.insert(index + 1, parent)

            old_group.nodes.remove(node)
            node.group = None

        if node.group is None:
            node.group = group

            if node not in group.nodes:
                group.nodes.append(node)

    def instantiate(self, group, tree):
        """Create diagram elements for every statement in *tree*.

        Elements are attached to *group*; nested group and statement blocks
        are handled recursively.  Returns *group* after ``update_order()``
        has been called on it.
        """
        for stmt in tree.stmts:
            # A node carrying a ``group`` attribute is rewritten as a group
            # statement containing just that node (the last such attribute
            # is the one consumed).
            if isinstance(stmt, parser.Node):
                group_attr = [a for a in stmt.attrs if a.name == 'group']
                if group_attr:
                    group_id = group_attr[-1]
                    stmt.attrs.remove(group_id)

                    if group_id.value != group.id:
                        stmt = parser.Group(group_id.value, [stmt])

            # Instantiate statements
            if isinstance(stmt, parser.Node):
                node = DiagramNode.get(stmt.id)
                node.set_attributes(stmt.attrs)
                self.belong_to(node, group)

            elif isinstance(stmt, parser.Edge):
                from_nodes = [DiagramNode.get(n) for n in stmt.from_nodes]
                to_nodes = [DiagramNode.get(n) for n in stmt.to_nodes]

                for node in from_nodes + to_nodes:
                    self.belong_to(node, group)

                # "A, B -> C, D" yields one edge per (from, to) pair.
                for node1 in from_nodes:
                    for node2 in to_nodes:
                        edge = DiagramEdge.get(node1, node2)
                        edge.set_dir(stmt.edge_type)
                        edge.set_attributes(stmt.attrs)

            elif isinstance(stmt, parser.Group):
                subgroup = NodeGroup.get(stmt.id)
                subgroup.level = group.level + 1
                self.belong_to(subgroup, group)
                self.instantiate(subgroup, stmt)

            elif isinstance(stmt, parser.Attr):
                group.set_attribute(stmt)

            elif isinstance(stmt, parser.Extension):
                if stmt.type == 'class':
                    name = unquote(stmt.name)
                    Diagram.classes[name] = stmt
                elif stmt.type == 'plugin':
                    self.diagram.set_plugin(stmt.name, stmt.attrs,
                                            config=self.config)

            elif isinstance(stmt, parser.Statements):
                self.instantiate(group, stmt)

        group.update_order()
        return group

    def bind_edges(self, group):
        """Attach ``DiagramEdge.find(node)`` results to each node's group.

        Recurses into every non-``DiagramNode`` member (subgroups).
        """
        for node in group.nodes:
            if isinstance(node, DiagramNode):
                group.edges += DiagramEdge.find(node)
            else:
                self.bind_edges(node)
137
138
class DiagramLayoutManager:
    """Assign grid positions (``xy``) to the nodes of a diagram or group.

    Positions are integer grid cells: a node covers ``colwidth`` x
    ``colheight`` cells starting at ``xy``.  Columns (x) are derived from the
    edges between nodes of this level; rows (y) are then assigned by a
    depth-first walk that skips cells already taken.  When the target is a
    whole ``Diagram``, every group is laid out first by its own manager.
    """

    def __init__(self, diagram):
        # The Diagram or NodeGroup whose direct nodes are positioned.
        self.diagram = diagram

        # Lists of nodes that form cycles along this level's edges.
        self.circulars = []
        # Ids of child nodes already given a row by set_node_ypos().
        self.heightRefs = []
        # Occupied grid cells (XY); a plain list, so it may hold duplicates.
        self.coordinates = []

    def run(self):
        """Compute the layout of the target diagram/group in place.

        For a whole ``Diagram``, each group yielded by ``traverse_groups()``
        is laid out first with a fresh manager of the same class.  Then the
        edges at this diagram's level are fetched, nodes are positioned,
        ``fixiate()`` is called and, for portrait orientation, the result is
        transposed.
        """
        if isinstance(self.diagram, Diagram):
            for group in self.diagram.traverse_groups():
                self.__class__(group).run()

        self.edges = DiagramEdge.find_by_level(self.diagram.level)
        self.do_layout()
        self.diagram.fixiate()

        if self.diagram.orientation == 'portrait':
            self.rotate_diagram()

    def rotate_diagram(self):
        """Transpose the computed layout (swap x/y and width/height).

        Every node from ``traverse_nodes()`` is transposed; nested groups also
        get their orientation flipped.  The diagram's own size is swapped too.
        """
        for node in self.diagram.traverse_nodes():
            node.xy = XY(node.xy.y, node.xy.x)
            node.colwidth, node.colheight = (node.colheight, node.colwidth)

            if isinstance(node, NodeGroup):
                if node.orientation == 'portrait':
                    node.orientation = 'landscape'
                else:
                    node.orientation = 'portrait'

        xy = (self.diagram.colheight, self.diagram.colwidth)
        self.diagram.colwidth, self.diagram.colheight = xy

    def do_layout(self):
        """Position all nodes: cycles, columns, ordering, then rows."""
        self.detect_circulars()

        self.set_node_xpos()
        self.adjust_node_order()

        # Every node left in column 0 roots a subtree; each new root starts
        # one row below the bottom-most cell occupied so far.
        height = 0
        for node in self.diagram.nodes:
            if node.xy.x == 0:
                self.set_node_ypos(node, height)
                height = max(xy.y for xy in self.coordinates) + 1

    def get_related_nodes(self, node, parent=False, child=False):
        """Return parents and/or children of *node* along unfolded edges.

        Only nodes in the same group as *node* are returned; *node* itself is
        excluded, duplicates are removed, and the result is sorted by
        ``order``.
        """
        uniq = {}
        for edge in self.edges:
            if edge.folded:
                continue

            if parent and edge.node2 == node:
                uniq[edge.node1] = 1
            elif child and edge.node1 == node:
                uniq[edge.node2] = 1

        related = []
        for uniq_node in uniq.keys():
            if uniq_node == node:
                pass
            elif uniq_node.group != node.group:
                pass
            else:
                related.append(uniq_node)

        related.sort(key=lambda x: x.order)
        return related

    def get_parent_nodes(self, node):
        """Return the same-group parents of *node* (see get_related_nodes)."""
        return self.get_related_nodes(node, parent=True)

    def get_child_nodes(self, node):
        """Return the same-group children of *node* (see get_related_nodes)."""
        return self.get_related_nodes(node, child=True)

    def detect_circulars(self):
        """Find cycles among this level's nodes and store them in circulars.

        A DFS starts from every node not already part of a recorded cycle.
        Afterwards, a cycle whose nodes are all contained in another one is
        dropped, and an overlapping pair is replaced by the concatenation of
        both lists (shared nodes then appear twice).
        """
        for node in self.diagram.nodes:
            if not [x for x in self.circulars if node in x]:
                self.detect_circulars_sub(node, [node])

        # remove part of other circular
        # (each inner loop breaks right after mutating self.circulars, so it
        # never keeps iterating over the modified list)
        for c1 in self.circulars[:]:
            for c2 in self.circulars:
                intersect = set(c1) & set(c2)

                if c1 != c2 and set(c1) == intersect:
                    if c1 in self.circulars:
                        self.circulars.remove(c1)
                    break

                if c1 != c2 and intersect:
                    if c1 in self.circulars:
                        self.circulars.remove(c1)
                    self.circulars.remove(c2)
                    self.circulars.append(c1 + c2)
                    break

    def detect_circulars_sub(self, node, parents):
        """Depth-first search from *node*; *parents* is the current path.

        When a child is already on the path, the path suffix starting at that
        child is recorded as a cycle (once).  There is no visited set shared
        between branches, so common sub-paths are explored repeatedly.
        """
        for child in self.get_child_nodes(node):
            if child in parents:
                i = parents.index(child)
                if parents[i:] not in self.circulars:
                    self.circulars.append(parents[i:])
            else:
                self.detect_circulars_sub(child, parents + [child])

    def is_circular_ref(self, node1, node2):
        """Return True if the edge node1 -> node2 is treated as a back-edge.

        Each cycle holding both nodes is examined in turn.  Its entry points
        (parents outside the cycle) are visited in ``order``.  An entry that
        points at node2 but not node1 makes the edge a back-edge (True); one
        that points at node1 only makes it a forward edge (False); one that
        points at both returns True when node1 comes later than node2 in the
        cycle list, and otherwise the next entry is consulted.  If no entry
        decides, True is returned when node1 comes later than node2 in the
        cycle list; otherwise the next cycle is examined.  False when nothing
        says True.
        """
        for circular in self.circulars:
            if node1 in circular and node2 in circular:
                parents = []
                for node in circular:
                    for parent in self.get_parent_nodes(node):
                        if parent not in circular:
                            parents.append(parent)

                for parent in sorted(parents, key=lambda x: x.order):
                    children = self.get_child_nodes(parent)
                    if node1 in children and node2 in children:
                        if circular.index(node1) > circular.index(node2):
                            return True
                    elif node2 in children:
                        return True
                    elif node1 in children:
                        return False
                # for-else: the loop above has no ``break``, so this runs
                # whenever no entry point returned.
                else:
                    if circular.index(node1) > circular.index(node2):
                        return True

        return False

    def set_node_xpos(self, depth=0):
        """Assign columns, handling column *depth* and then deeper ones.

        Each child of a node in column *depth* is moved to the column right
        after its parent (``xy.x + colwidth``, with y reset to 0), unless the
        edge is a cycle back-edge, a self-loop, or the child already lies
        strictly further right.  Recurses once per column while any node
        lies beyond *depth*.
        NOTE(review): one recursion level per column -- very wide diagrams
        could approach Python's recursion limit.
        """
        for node in self.diagram.nodes:
            if node.xy.x != depth:
                continue

            for child in self.get_child_nodes(node):
                if self.is_circular_ref(node, child):
                    pass
                elif node == child:
                    pass
                elif child.xy.x > node.xy.x + node.colwidth:
                    pass
                else:
                    child.xy = XY(node.xy.x + node.colwidth, 0)

        depther_node = [x for x in self.diagram.nodes if x.xy.x > depth]
        if len(depther_node) > 0:
            self.set_node_xpos(depth + 1)

    def adjust_node_order(self):
        """Reorder ``self.diagram.nodes`` so that related nodes sit together.

        For each node:
        * consecutive parents in the same column are made adjacent;
        * consecutive children are made adjacent when in the same column;
          across columns, the one further right is moved after the other
          (skipped for cycle back-edges);
        * for a group node, its children are repeatedly reordered with
          ``compare_child_node_order`` until a pass moves nothing.
        Finally ``update_order()`` is called on the diagram.

        NOTE(review): idx1/idx2 are read before ``remove()``; in the
        cross-column branch and the group branch, when the removed node
        precedes the anchor, the anchor shifts left by one and ``insert``
        lands one slot past it rather than directly after it.  Confirm
        whether that is intended before changing it (it affects layouts).
        """
        for node in list(self.diagram.nodes):
            parents = self.get_parent_nodes(node)
            if len(set(parents)) > 1:
                for i in range(1, len(parents)):
                    node1 = parents[i - 1]
                    node2 = parents[i]

                    if node1.xy.x == node2.xy.x:
                        idx1 = self.diagram.nodes.index(node1)
                        idx2 = self.diagram.nodes.index(node2)

                        if idx1 < idx2:
                            self.diagram.nodes.remove(node2)
                            self.diagram.nodes.insert(idx1 + 1, node2)
                        else:
                            self.diagram.nodes.remove(node1)
                            self.diagram.nodes.insert(idx2 + 1, node1)

            children = self.get_child_nodes(node)
            if len(set(children)) > 1:
                for i in range(1, len(children)):
                    node1 = children[i - 1]
                    node2 = children[i]

                    idx1 = self.diagram.nodes.index(node1)
                    idx2 = self.diagram.nodes.index(node2)

                    if node1.xy.x == node2.xy.x:
                        if idx1 < idx2:
                            self.diagram.nodes.remove(node2)
                            self.diagram.nodes.insert(idx1 + 1, node2)
                        else:
                            self.diagram.nodes.remove(node1)
                            self.diagram.nodes.insert(idx2 + 1, node1)
                    elif self.is_circular_ref(node1, node2):
                        pass
                    else:
                        if node1.xy.x < node2.xy.x:
                            self.diagram.nodes.remove(node2)
                            self.diagram.nodes.insert(idx1 + 1, node2)
                        else:
                            self.diagram.nodes.remove(node1)
                            self.diagram.nodes.insert(idx2 + 1, node1)

            if isinstance(node, NodeGroup):
                children = self.get_child_nodes(node)
                if len(set(children)) > 1:
                    # Repeat passes until no pair is moved.
                    while True:
                        exchange = 0

                        for i in range(1, len(children)):
                            node1 = children[i - 1]
                            node2 = children[i]

                            idx1 = self.diagram.nodes.index(node1)
                            idx2 = self.diagram.nodes.index(node2)
                            ret = self.compare_child_node_order(node,
                                                                node1, node2)

                            if ret > 0 and idx1 < idx2:
                                self.diagram.nodes.remove(node1)
                                self.diagram.nodes.insert(idx2 + 1, node1)
                                exchange += 1

                        if exchange == 0:
                            break

        self.diagram.update_order()

    def compare_child_node_order(self, parent, node1, node2):
        """Decide the order of two children of the group *parent*.

        The edges from ``DiagramEdge.find(parent, node1)`` and
        ``DiagramEdge.find(parent, node2)`` are sorted by the ``order`` of
        their source node.  Returns 1 when the earliest edge leads to node2
        (the caller then moves node1 after node2), -1 otherwise, and 0 when
        there are no edges at all.
        """
        def compare(x, y):
            # Work on copies so the real edges are not modified.
            x = x.duplicate()
            y = y.duplicate()
            # NOTE(review): this only runs while both sources are equal and
            # climbs both in lockstep, so they stay equal and the result is
            # 0 either way; it may have been meant to climb to the first
            # differing ancestors.
            while x.node1 == y.node1 and x.node1.group is not None:
                x.node1 = x.node1.group
                y.node1 = y.node1.group

            # cmp x.node1.order and y.node1.order
            if x.node1.order < y.node1.order:
                return -1
            elif x.node1.order == y.node1.order:
                return 0
            else:
                return 1

        edges = (DiagramEdge.find(parent, node1) +
                 DiagramEdge.find(parent, node2))
        edges.sort(key=cmp_to_key(compare))
        if len(edges) == 0:
            return 0
        elif edges[0].node2 == node2:
            return 1
        else:
            return -1

    def mark_xy(self, xy, width, height):
        """Record the *width* x *height* block of cells at *xy* as occupied."""
        for w in range(width):
            for h in range(height):
                self.coordinates.append(XY(xy.x + w, xy.y + h))

    def set_node_ypos(self, node, height=0):
        """Place *node* at row *height*, then place its subtree.

        Returns False without placing *node* if any cell of its footprint at
        that row is already taken.  Otherwise *node* is placed and marked,
        and each child that lies to its right and is not yet in heightRefs
        is placed recursively at the first row (starting from the running
        *height*) where that succeeds.  Returns False if the first such child
        cannot be placed at its starting row -- *node*'s own cells stay
        marked in that case.  Returns True otherwise.

        Extra adjustments of the starting row:
        * for a group node, raise it to ``get_parent_node_ypos(node, child)``
          when that is larger;
        * when more than one child has children of its own and the previous
          child does not converge with this one (``is_rhombus``), move below
          the lowest occupied cell in columns right of the child's column,
          if that cell is at or below *node*'s row.
        """
        for x in range(node.colwidth):
            for y in range(node.colheight):
                xy = XY(node.xy.x + x, height + y)
                if xy in self.coordinates:
                    return False
        node.xy = XY(node.xy.x, height)
        self.mark_xy(node.xy, node.colwidth, node.colheight)

        # NOTE(review): compares x.xy.x with y.xy.y (an x-coordinate against
        # a y-coordinate); probably meant y.xy.x.  Changing it alters the
        # child order and therefore layouts -- confirm before fixing.
        def cmp(x, y):
            if x.xy.x < y.xy.y:
                return -1
            elif x.xy.x == y.xy.y:
                return 0
            else:
                return 1

        count = 0
        children = self.get_child_nodes(node)
        children.sort(key=cmp_to_key(cmp))

        # Number of children that have children of their own.
        grandchild = 0
        for child in children:
            if self.get_child_nodes(child):
                grandchild += 1

        prev_child = None
        for child in children:
            if child.id in self.heightRefs:
                pass
            elif node.xy.x >= child.xy.x:
                pass
            else:
                if isinstance(node, NodeGroup):
                    parent_height = self.get_parent_node_ypos(node, child)
                    if parent_height and parent_height > height:
                        height = parent_height

                if (prev_child and grandchild > 1 and
                   (not self.is_rhombus(prev_child, child))):
                    coord = [p.y for p in self.coordinates if p.x > child.xy.x]
                    if coord and max(coord) >= node.xy.y:
                        height = max(coord) + 1

                while True:
                    # The recursive call already sets child.xy and marks its
                    # cells; the lines below repeat that with the same row,
                    # so coordinates gets duplicate entries (harmless for the
                    # membership checks).  A failed recursive attempt may
                    # leave cells marked at the rows it tried.
                    if self.set_node_ypos(child, height):
                        child.xy = XY(child.xy.x, height)
                        self.mark_xy(child.xy, child.colwidth, child.colheight)
                        self.heightRefs.append(child.id)

                        count += 1
                        break
                    else:
                        if count == 0:
                            return False

                        height += 1

                height += 1
                prev_child = child

        return True

    def is_rhombus(self, node1, node2):
        """Return True if the single-child chains from node1 and node2 meet.

        Both chains are followed in lockstep while each current node has
        exactly one child and neither step moves to a lower column; the
        result is True once both chains reach the same node.  That column
        check is what stops the walk on a cycle.
        """
        ret = False
        while True:
            if node1 == node2:
                ret = True
                break

            child1 = self.get_child_nodes(node1)
            child2 = self.get_child_nodes(node2)

            if len(child1) != 1 or len(child2) != 1:
                break
            elif node1.xy.x > child1[0].xy.x or node2.xy.x > child2[0].xy.x:
                break
            else:
                node1 = child1[0]
                node2 = child2[0]

        return ret

    def get_parent_node_ypos(self, parent, child):
        """Return the smallest row of an edge source from *parent* to *child*.

        For each edge from ``DiagramEdge.find(parent, child)``, the row is
        ``parent.xy.y`` plus the ``xy.y`` of every element on the ``.group``
        chain from the edge's source up to (excluding) *parent* -- presumably
        turning group-relative rows into this level's rows.  Returns None
        when there are no such edges.
        NOTE(review): assumes every edge source is nested inside *parent*;
        otherwise the ``.group`` walk never reaches *parent*.
        """
        heights = []
        for e in DiagramEdge.find(parent, child):
            y = parent.xy.y

            node = e.node1
            while node != parent:
                y += node.xy.y
                node = node.group

            heights.append(y)

        if heights:
            return min(heights)
        else:
            return None
488
489
class EdgeLayoutManager(object):
    """Flag edges whose straight route would run across another node.

    Runs after node positions are final.  For every visible edge, the grid
    cells its route passes through are derived from the edge direction and
    the orientation of the source node's group.  If a drawable node occupies
    any of those cells, ``edge.skipped`` is set to ``1``.  Edges are never
    un-flagged here.
    """

    def __init__(self, diagram):
        self.diagram = diagram

    @property
    def groups(self):
        """Yield the non-drawable elements (groups) whose edges are checked.

        For a separated diagram only its direct nodes are considered;
        otherwise every group, in pre-order.
        """
        if self.diagram.separated:
            seq = self.diagram.nodes
        else:
            seq = self.diagram.traverse_groups(preorder=True)

        for group in seq:
            if not group.drawable:
                yield group

    @property
    def nodes(self):
        """Yield the drawable nodes that can block an edge route."""
        if self.diagram.separated:
            seq = self.diagram.nodes
        else:
            seq = self.diagram.traverse_nodes()

        for node in seq:
            if node.drawable:
                yield node

    @property
    def edges(self):
        """Yield the diagram's and the groups' edges, except style 'none'."""
        for edge in (e for e in self.diagram.edges if e.style != 'none'):
            yield edge

        for group in self.groups:
            for edge in (e for e in group.edges if e.style != 'none'):
                yield edge

    def run(self):
        """Set ``skipped = 1`` on every edge whose route crosses a node."""
        # Node positions do not change during this pass.  Collect the
        # occupied cells once instead of rescanning every node for every
        # cell of every edge.  Plain (x, y) tuples keep the lookup
        # independent of how XY implements hashing.
        occupied = set((node.xy.x, node.xy.y) for node in self.nodes)

        for edge in self.edges:
            for cell in self._route_cells(edge):
                if cell in occupied:
                    edge.skipped = 1
                    break

    def _route_cells(self, edge):
        """Return the (x, y) grid cells that the route of *edge* crosses.

        Directions that are not handled for the source group's orientation
        produce no cells, so such edges are never flagged.
        """
        direction = edge.direction
        x1, y1 = edge.node1.xy.x, edge.node1.xy.y
        x2, y2 = edge.node2.xy.x, edge.node2.xy.y

        def row(y, start, stop):
            # cells (x, y) for start <= x < stop
            return [(x, y) for x in range(start, stop)]

        def column(x, start, stop):
            # cells (x, y) for start <= y < stop
            return [(x, y) for y in range(start, stop)]

        cells = []
        if edge.node1.group.orientation == 'landscape':
            if direction in ('right', 'right-up'):
                cells = row(y1, x1 + 1, x2)
            elif direction == 'right-down':
                if self.diagram.edge_layout == 'flowchart':
                    # flowchart layout also checks the source column down to
                    # (and including) the target's row
                    cells += column(x1, y1 + 1, y2 + 1)
                cells += row(y2, x1 + 1, x2)
            elif direction in ('left-down', 'down'):
                cells = column(x1, y1 + 1, y2)
            elif direction == 'up':
                cells = column(x1, y2 + 1, y1)
        else:
            if direction == 'right':
                cells = row(y1, x1 + 1, x2)
            elif direction in ('left-down', 'down'):
                cells = column(x1, y1 + 1, y2)
            elif direction == 'right-down':
                if self.diagram.edge_layout == 'flowchart':
                    # flowchart layout also checks the source row across to
                    # (and including) the target's column
                    cells += row(y1, x1 + 1, x2 + 1)
                cells += column(x2, y1 + 1, y2)

        return cells
603
604
class ScreenNodeBuilder:
    """Build a diagram from a parsed tree and lay it out for drawing."""

    @classmethod
    def build(cls, tree, config=None, layout=True):
        """Reset element class state, then build and lay out *tree*.

        :param tree: parser output handed to ``DiagramTreeBuilder``.
        :param config: configuration forwarded to the tree builder.
        :param layout: when false, node positioning is skipped and only the
            edge pass runs.
        :return: whatever ``run()`` returns for the new instance (the
            diagram for this class; a subclass may override ``run``).
        """
        # Clear class-level state (presumably registries of previously
        # created elements) in a fixed order before building.
        for element_class in (DiagramNode, DiagramEdge, NodeGroup, Diagram):
            element_class.clear()

        builder = cls(tree, config, layout)
        return builder.run()

    def __init__(self, tree, config, layout):
        self.diagram = DiagramTreeBuilder().build(tree, config)
        self.config = config
        self.layout = layout

    def run(self):
        """Run the layout passes on the built diagram and return it."""
        diagram = self.diagram
        if not self.layout:
            # Positions are left as built; only edge routing is checked.
            EdgeLayoutManager(diagram).run()
            return diagram

        DiagramLayoutManager(diagram).run()
        diagram.fixiate(True)
        EdgeLayoutManager(diagram).run()
        return diagram
628
629
class SeparateDiagramBuilder(ScreenNodeBuilder):
    """Produce one laid-out diagram per group, then one for the whole diagram.

    ``run`` is a generator.  For every group it builds a diagram containing
    the group itself (its direct subgroups collapsed) and the nodes found at
    the other end of ``DiagramEdge.find(None, group)`` /
    ``DiagramEdge.find(group, None)``, lays that diagram out and yields it.
    """

    @property
    def _groups(self):
        """Yield every group, then the diagram, restoring state in between.

        The nodes/edges/level of the diagram and of every group, and the
        group/order of every node, are saved up front.  After each group has
        been consumed (when the next item is requested), that state is
        restored and per-node layout state (xy, colwidth, colheight,
        separated) and per-edge state (skipped, crosspoints) are reset.
        No restoration happens after the final diagram is yielded.
        """
        # Store nodes and edges of subgroups
        nodes = {self.diagram: self.diagram.nodes}
        edges = {self.diagram: self.diagram.edges}
        levels = {self.diagram: self.diagram.level}
        for group in self.diagram.traverse_groups():
            nodes[group] = group.nodes
            edges[group] = group.edges
            levels[group] = group.level

        groups = {}
        orders = {}
        for node in self.diagram.traverse_nodes():
            groups[node] = node.group
            orders[node] = node.order

        for group in self.diagram.traverse_groups():
            yield group

            # Restore nodes, groups and edges
            for g in nodes:
                g.nodes = nodes[g]
                g.edges = edges[g]
                g.level = levels[g]

            for n in groups:
                n.group = groups[n]
                n.order = orders[n]
                n.xy = XY(0, 0)
                n.colwidth = 1
                n.colheight = 1
                n.separated = False

            for edge in DiagramEdge.find_all():
                edge.skipped = False
                edge.crosspoints = []

        yield self.diagram

    def _filter_edges(self, edges, parent, level):
        """Project *edges* onto the nodes visible at ``level + 1``.

        An edge is kept only when ``is_parent(parent)`` holds for the groups
        of both endpoints.  An endpoint nested deeper than *level* is
        replaced (on a duplicate edge) by its ancestor at ``level + 1``.
        Edges are de-duplicated by their (node1, node2) pair, the last one
        winning.

        :return: a list of edges.  A list rather than a ``dict.values()``
            view, because the result is stored as an ``edges`` attribute
            that is a list everywhere else.
        """
        filtered = {}
        for e in edges:
            if e.node1.group.is_parent(parent):
                if e.node1.group.level > level:
                    e = e.duplicate()
                    if isinstance(e.node1, NodeGroup):
                        e.node1 = e.node1.parent(level + 1)
                    else:
                        e.node1 = e.node1.group.parent(level + 1)
            else:
                continue

            if e.node2.group.is_parent(parent):
                if e.node2.group.level > level:
                    e = e.duplicate()
                    if isinstance(e.node2, NodeGroup):
                        e.node2 = e.node2.parent(level + 1)
                    else:
                        e.node2 = e.node2.group.parent(level + 1)
            else:
                continue

            filtered[(e.node1, e.node2)] = e

        return list(filtered.values())

    def run(self):
        """Yield one laid-out diagram per group, then one for the diagram.

        The yielded diagrams share node objects with ``self.diagram``, and
        ``_groups`` restores their state only when the next item is
        requested.  A consumer should therefore finish with each result
        before advancing.
        """
        for group in self._groups:
            base = self.diagram.duplicate()
            base.level = group.level - 1

            # bind edges on base diagram (outer the group)
            edges = (DiagramEdge.find(None, group) +
                     DiagramEdge.find(group, None))
            base.edges = self._filter_edges(edges, self.diagram, group.level)

            # bind edges on target group (inner the group); collect them
            # before group.edges is reset below
            edges = list(group.edges)
            for subgroup in group.traverse_groups():
                edges.extend(subgroup.edges)

            group.edges = []
            for e in self._filter_edges(edges, group, group.level):
                # after projection, an edge inside a collapsed subgroup
                # starts and ends at that subgroup; drop it
                if isinstance(e.node1, NodeGroup) and e.node1 == e.node2:
                    continue
                group.edges.append(e)

            # clear subgroups in the group
            for g in group.nodes:
                if isinstance(g, NodeGroup):
                    g.nodes = []
                    g.edges = []
                    g.separated = True

            # pick up nodes to base diagram
            nodes1 = [e.node1 for e in DiagramEdge.find(None, group)]
            nodes1.sort(key=lambda x: x.order)
            nodes2 = [e.node2 for e in DiagramEdge.find(group, None)]
            nodes2.sort(key=lambda x: x.order)

            nodes = nodes1 + [group] + nodes2
            for i, n in enumerate(nodes):
                n.order = i
                if n not in base.nodes:
                    base.nodes.append(n)
                    n.group = base

            # the whole diagram is laid out as itself, not as a duplicate
            if isinstance(group, Diagram):
                base = group

            DiagramLayoutManager(base).run()
            base.fixiate(True)
            EdgeLayoutManager(base).run()

            yield base
746