1# -*- coding: utf-8 -*-
2
3# Copyright (c) 2004 - 2021 Detlev Offenbach <detlev@die-offenbachs.de>
4#
5
6"""
7Module implementing a graphics item for an association between two items.
8"""
9
10import enum
11
12from PyQt5.QtCore import QPointF, QRectF, QLineF
13from PyQt5.QtWidgets import QGraphicsItem
14
15from E5Graphics.E5ArrowItem import E5ArrowItem, E5ArrowType
16
17import Utilities
18
19
class AssociationType(enum.Enum):
    """
    Class enumerating the kinds of associations that can be drawn.

    The integer values are used when association items are persisted, so
    existing values must not be changed or reused.
    """
    # plain association
    NORMAL = 0
    # generalisation relationship
    GENERALISATION = 1
    # import relationship
    IMPORTS = 2
28
class AssociationPointRegion(enum.Enum):
    """
    Class enumerating where a point lies relative to the two diagonals of an
    item's rectangle.

    WEST, NORTH, EAST and SOUTH are the four side regions bounded by the
    diagonals, the corner values denote a point lying exactly on a diagonal,
    CENTER a point lying on both diagonals and NO_REGION that no region
    could be determined.
    """
    (
        NO_REGION,      # 0
        WEST,           # 1: region 1
        NORTH,          # 2: region 2
        EAST,           # 3: region 3
        SOUTH,          # 4: region 4
        NORTH_WEST,     # 5
        NORTH_EAST,     # 6
        SOUTH_EAST,     # 7
        SOUTH_WEST,     # 8
        CENTER,         # 9
    ) = range(10)
44
class AssociationItem(E5ArrowItem):
    """
    Class implementing a graphics item for an association between two items.

    The association is drawn as an arrow starting at the first item and
    ending at the second one.
    """
    def __init__(self, itemA, itemB, assocType=AssociationType.NORMAL,
                 topToBottom=False, colors=None, parent=None):
        """
        Constructor

        @param itemA first widget of the association
        @type UMLItem
        @param itemB second widget of the association
        @type UMLItem
        @param assocType type of the association (an AssociationType member
            or its integer value)
        @type AssociationType or int
        @param topToBottom flag indicating to draw the association
            from item A top to item B bottom
        @type bool
        @param colors tuple containing the foreground and background colors
        @type tuple of (QColor, QColor)
        @param parent reference to the parent object
        @type QGraphicsItem
        @exception ValueError raised to indicate an invalid association type
        """
        # Normalize the type: a member is returned unchanged, an integer is
        # looked up and an unknown value raises a clear ValueError instead of
        # leaving the arrow attributes below unbound.
        assocType = AssociationType(assocType)
        if assocType == AssociationType.GENERALISATION:
            arrowType = E5ArrowType.WIDE
            arrowFilled = False
        else:
            # NORMAL and IMPORTS share the filled standard arrow head
            arrowType = E5ArrowType.NORMAL
            arrowFilled = True

        # placeholder geometry; the real end points are calculated below
        E5ArrowItem.__init__(self, QPointF(0, 0), QPointF(100, 100),
                             arrowFilled, arrowType, colors, parent)

        # the association is positioned by its items (see widgetMoved())
        self.setFlag(QGraphicsItem.GraphicsItemFlag.ItemIsMovable, False)
        self.setFlag(QGraphicsItem.GraphicsItemFlag.ItemIsSelectable, False)

        if topToBottom:
            self.calculateEndingPoints = (
                self.__calculateEndingPoints_topToBottom
            )
        else:
            # __calculateEndingPoints_center is an alternative strategy
            # that could be selected here instead
            self.calculateEndingPoints = self.__calculateEndingPoints_rectangle

        self.itemA = itemA
        self.itemB = itemB
        self.assocType = assocType
        self.topToBottom = topToBottom

        self.regionA = AssociationPointRegion.NO_REGION
        self.regionB = AssociationPointRegion.NO_REGION

        self.calculateEndingPoints()

        self.itemA.addAssociation(self)
        self.itemB.addAssociation(self)

    def __mapRectFromItem(self, item):
        """
        Private method to map item's rectangle to this item's coordinate
        system.

        Only the top-left corner is mapped; width and height are taken over
        unchanged.

        @param item reference to the item to be mapped
        @type QGraphicsRectItem
        @return item's rectangle in local coordinates
        @rtype QRectF
        """
        rect = item.rect()
        tl = self.mapFromItem(item, rect.topLeft())
        return QRectF(tl.x(), tl.y(), rect.width(), rect.height())

    def __calculateEndingPoints_topToBottom(self):
        """
        Private method to calculate the ending points of the association item.

        The association connects the facing horizontal edges of the two
        items. If item A lies below item B (compared by their centers), it
        runs from the top center of item A to the bottom center of item B;
        otherwise from the bottom center of item A to the top center of
        item B.
        """
        if self.itemA is None or self.itemB is None:
            return

        self.prepareGeometryChange()

        rectA = self.__mapRectFromItem(self.itemA)
        rectB = self.__mapRectFromItem(self.itemB)
        centerXA = rectA.x() + rectA.width() / 2.0
        centerXB = rectB.x() + rectB.width() / 2.0
        centerYA = rectA.y() + rectA.height() / 2.0
        centerYB = rectB.y() + rectB.height() / 2.0
        # Qt's y axis points downwards, i.e. a larger y means "lower"
        if centerYA > centerYB:
            startP = QPointF(centerXA, rectA.y())
            endP = QPointF(centerXB, rectB.y() + rectB.height())
        else:
            startP = QPointF(centerXA, rectA.y() + rectA.height())
            endP = QPointF(centerXB, rectB.y())
        self.setPoints(startP.x(), startP.y(), endP.x(), endP.y())

    def __calculateEndingPoints_center(self):
        """
        Private method to calculate the ending points of the association item.

        The ending points are the points where the line connecting the
        centers of the two associated items crosses the items' borders. If
        no such point is found for either item, the current geometry is
        kept. This strategy is not selected by the constructor.
        """
        if self.itemA is None or self.itemB is None:
            return

        self.prepareGeometryChange()

        rectA = self.__mapRectFromItem(self.itemA)
        rectB = self.__mapRectFromItem(self.itemB)
        midA = QPointF(rectA.x() + rectA.width() / 2.0,
                       rectA.y() + rectA.height() / 2.0)
        midB = QPointF(rectB.x() + rectB.width() / 2.0,
                       rectB.y() + rectB.height() / 2.0)
        startP = self.__findRectIntersectionPoint(self.itemA, midA, midB)
        endP = self.__findRectIntersectionPoint(self.itemB, midB, midA)

        # (-1, -1) is the "no intersection" marker of the helper
        if (
            startP.x() != -1 and
            startP.y() != -1 and
            endP.x() != -1 and
            endP.y() != -1
        ):
            self.setPoints(startP.x(), startP.y(), endP.x(), endP.y())

    def __calculateEndingPoints_rectangle(self):
        r"""
        Private method to calculate the ending points of the association item.

        The ending points are calculated by the following method.

        For each item the diagram is divided in four regions by the
        diagonals of the item's rectangle as indicated below
        <pre>
            +------------------------------+
            |        \  Region 2  /        |
            |         \          /         |
            |          |--------|          |
            |          | \    / |          |
            |          |  \  /  |          |
            |          |   \/   |          |
            | Region 1 |   /\   | Region 3 |
            |          |  /  \  |          |
            |          | /    \ |          |
            |          |--------|          |
            |         /          \         |
            |        /  Region 4  \        |
            +------------------------------+
        </pre>

        Each diagonal is defined by two corners of the bounding rectangle,
        so both diagonals cross at the item's center.

        To calculate the start point we have to find out in which
        region (defined by itemA's diagonals) the center of itemB lies
        (let's call it region M). After that the start point will be
        the middle point of the rectangle's side contained in region M.
        A point lying exactly on a diagonal is assigned to one of the two
        adjacent regions, the center itself to region 1.

        To calculate the end point we repeat the above but in the opposite
        direction (from itemB to itemA).
        """
        if self.itemA is None or self.itemB is None:
            return

        self.prepareGeometryChange()

        rectA = self.__mapRectFromItem(self.itemA)
        rectB = self.__mapRectFromItem(self.itemB)

        xA = rectA.x() + rectA.width() / 2.0
        yA = rectA.y() + rectA.height() / 2.0
        xB = rectB.x() + rectB.width() / 2.0
        yB = rectB.y() + rectB.height() / 2.0

        # The real rectangles are passed so that the diagonals cross at the
        # item centers (a rectangle anchored at the center would shift all
        # regions by half the item size).
        self.regionA = self.__normalizeRegion(
            self.__findPointRegion(rectA, xB, yB))
        self.__updateEndPoint(self.regionA, True)

        # now do the same for itemB
        self.regionB = self.__normalizeRegion(
            self.__findPointRegion(rectB, xA, yA))
        self.__updateEndPoint(self.regionB, False)

    @staticmethod
    def __normalizeRegion(region):
        """
        Private method to map a region to one of the four side regions.

        A point on a diagonal is assigned to one of the two neighboring side
        regions (north-east to north, south-east to east, south-west to
        south, north-west to west) and the center to the west region. Side
        regions and NO_REGION are returned unchanged.

        @param region region to be mapped
        @type AssociationPointRegion
        @return side region or NO_REGION
        @rtype AssociationPointRegion
        """
        return {
            AssociationPointRegion.NORTH_EAST: AssociationPointRegion.NORTH,
            AssociationPointRegion.SOUTH_EAST: AssociationPointRegion.EAST,
            AssociationPointRegion.SOUTH_WEST: AssociationPointRegion.SOUTH,
            AssociationPointRegion.NORTH_WEST: AssociationPointRegion.WEST,
            AssociationPointRegion.CENTER: AssociationPointRegion.WEST,
        }.get(region, region)

    def __findPointRegion(self, rect, posX, posY):
        """
        Private method to find out, which region of rectangle rect contains
        the point (posX, posY) and returns the region number.

        The regions are bounded by the two diagonals of rect, which cross at
        its center. Diagonal 1 runs from the top-right to the bottom-left
        corner, diagonal 2 from the top-left to the bottom-right corner.
        The rectangle is assumed to be normalized (non-negative width and
        height).

        @param rect rectangle to calculate the region for
        @type QRectF
        @param posX x position of point
        @type float
        @param posY y position of point
        @type float
        @return the calculated region<br />
            West = Region 1<br />
            North = Region 2<br />
            East = Region 3<br />
            South = Region 4<br />
            NorthWest = On diagonal 2 between Region 1 and 2<br />
            NorthEast = On diagonal 1 between Region 2 and 3<br />
            SouthEast = On diagonal 2 between Region 3 and 4<br />
            SouthWest = On diagonal 1 between Region 4 and 1<br />
            Center = On diagonal 1 and On diagonal 2 (the center)<br />
            NoRegion = the point could not be classified (e.g. NaN input)
        @rtype AssociationPointRegion
        """
        w = rect.width()
        h = rect.height()
        dx = posX - (rect.x() + w / 2.0)
        dy = posY - (rect.y() + h / 2.0)

        # Scaled signed offsets of the point from diagonal 1 and diagonal 2;
        # a negative value means the point lies left of that diagonal.
        # Cross-multiplying instead of dividing by the height avoids a
        # ZeroDivisionError for zero-height rectangles. Qt's y axis points
        # downwards, so "north" means a smaller y.
        side1 = h * dx + w * dy
        side2 = h * dx - w * dy

        if side1 < 0 and side2 < 0:
            return AssociationPointRegion.WEST
        if side1 < 0 and side2 > 0:
            return AssociationPointRegion.NORTH
        if side1 > 0 and side2 > 0:
            return AssociationPointRegion.EAST
        if side1 > 0 and side2 < 0:
            return AssociationPointRegion.SOUTH
        # on diagonal 2, left of diagonal 1 -> above the center
        if side1 < 0 and side2 == 0:
            return AssociationPointRegion.NORTH_WEST
        # on diagonal 1, right of diagonal 2 -> above the center
        if side1 == 0 and side2 > 0:
            return AssociationPointRegion.NORTH_EAST
        # on diagonal 2, right of diagonal 1 -> below the center
        if side1 > 0 and side2 == 0:
            return AssociationPointRegion.SOUTH_EAST
        # on diagonal 1, left of diagonal 2 -> below the center
        if side1 == 0 and side2 < 0:
            return AssociationPointRegion.SOUTH_WEST
        if side1 == 0 and side2 == 0:
            return AssociationPointRegion.CENTER

        # only reachable for unordered values such as NaN
        return AssociationPointRegion.NO_REGION

    def __updateEndPoint(self, region, isWidgetA):
        """
        Private method to update an endpoint.

        The endpoint is placed at the middle of the item's side facing the
        given region. Regions other than the four sides and CENTER (which is
        treated like SOUTH) leave the endpoint unchanged.

        @param region the region for the endpoint
        @type AssociationPointRegion
        @param isWidgetA flag indicating update for itemA is done
        @type bool
        """
        if region == AssociationPointRegion.NO_REGION:
            return

        rect = self.__mapRectFromItem(self.itemA if isWidgetA else self.itemB)
        x = rect.x()
        y = rect.y()
        ww = rect.width()
        wh = rect.height()
        ch = wh / 2.0
        cw = ww / 2.0

        if region == AssociationPointRegion.WEST:
            px = x
            py = y + ch
        elif region == AssociationPointRegion.NORTH:
            px = x + cw
            py = y
        elif region == AssociationPointRegion.EAST:
            px = x + ww
            py = y + ch
        elif region in (
            AssociationPointRegion.SOUTH,
            AssociationPointRegion.CENTER
        ):
            px = x + cw
            py = y + wh
        else:
            # corner regions are expected to be normalized by the caller
            return

        if isWidgetA:
            self.setStartPoint(px, py)
        else:
            self.setEndPoint(px, py)

    def __findRectIntersectionPoint(self, item, p1, p2):
        """
        Private method to find the intersection point of a line with a
        rectangle.

        @param item item to check against
        @type UMLItem
        @param p1 first point of the line
        @type QPointF
        @param p2 second point of the line
        @type QPointF
        @return the intersection point or QPointF(-1.0, -1.0), if the
            segment p1p2 does not cross the item's border
        @rtype QPointF
        """
        rect = self.__mapRectFromItem(item)
        lines = [
            QLineF(rect.topLeft(), rect.topRight()),
            QLineF(rect.topLeft(), rect.bottomLeft()),
            QLineF(rect.bottomRight(), rect.bottomLeft()),
            QLineF(rect.bottomRight(), rect.topRight())
        ]
        intersectLine = QLineF(p1, p2)
        # filled in place by QLineF.intersect()
        intersectPoint = QPointF(0, 0)
        for line in lines:
            if (
                intersectLine.intersect(line, intersectPoint) ==
                QLineF.IntersectType.BoundedIntersection
            ):
                return intersectPoint
        return QPointF(-1.0, -1.0)

    def __findIntersection(self, p1, p2, p3, p4):
        """
        Private method to calculate the intersection point of two lines.

        The first line is determined by the points p1 and p2, the second
        line by p3 and p4. If the intersection point is not contained in
        the segment p1p2, then it returns (-1.0, -1.0).

        For the function's internal calculations remember:<br />
        QT coordinates start with the point (0,0) as the topleft corner
        and x-values increase from left to right and y-values increase
        from top to bottom; it means the visible area is quadrant I in
        the regular XY coordinate system

        <pre>
            Quadrant II     |   Quadrant I
           -----------------|-----------------
            Quadrant III    |   Quadrant IV
        </pre>

        In order for the linear function calculations to work in this method
        we must switch x and y values (x values become y values and viceversa)

        @param p1 first point of first line
        @type QPointF
        @param p2 second point of first line
        @type QPointF
        @param p3 first point of second line
        @type QPointF
        @param p4 second point of second line
        @type QPointF
        @return the intersection point
        @rtype QPointF
        """
        x1 = p1.y()
        y1 = p1.x()
        x2 = p2.y()
        y2 = p2.x()
        x3 = p3.y()
        y3 = p3.x()
        x4 = p4.y()
        y4 = p4.x()

        # line 1 is the line between (x1, y1) and (x2, y2)
        # line 2 is the line between (x3, y3) and (x4, y4)
        no_line1 = True    # it is false, if line 1 is a linear function
        no_line2 = True    # it is false, if line 2 is a linear function
        slope1 = 0.0
        slope2 = 0.0
        b1 = 0.0
        b2 = 0.0

        if x2 != x1:
            slope1 = (y2 - y1) / (x2 - x1)
            b1 = y1 - slope1 * x1
            no_line1 = False
        if x4 != x3:
            slope2 = (y4 - y3) / (x4 - x3)
            b2 = y3 - slope2 * x3
            no_line2 = False

        pt = QPointF()
        # if either line is not a function
        if no_line1 and no_line2:
            # if the lines are not the same one
            if x1 != x3:
                return QPointF(-1.0, -1.0)
            # if the lines are the same ones
            if y3 <= y4:
                if y3 <= y1 and y1 <= y4:
                    return QPointF(y1, x1)
                else:
                    return QPointF(y2, x2)
            else:
                if y4 <= y1 and y1 <= y3:
                    return QPointF(y1, x1)
                else:
                    return QPointF(y2, x2)
        elif no_line1:
            pt.setX(slope2 * x1 + b2)
            pt.setY(x1)
            if y1 >= y2:
                if not (y2 <= pt.x() and pt.x() <= y1):
                    pt.setX(-1.0)
                    pt.setY(-1.0)
            else:
                if not (y1 <= pt.x() and pt.x() <= y2):
                    pt.setX(-1.0)
                    pt.setY(-1.0)
            return pt
        elif no_line2:
            pt.setX(slope1 * x3 + b1)
            pt.setY(x3)
            if y3 >= y4:
                if not (y4 <= pt.x() and pt.x() <= y3):
                    pt.setX(-1.0)
                    pt.setY(-1.0)
            else:
                if not (y3 <= pt.x() and pt.x() <= y4):
                    pt.setX(-1.0)
                    pt.setY(-1.0)
            return pt

        # parallel lines never intersect (identical ones are not detected)
        if slope1 == slope2:
            pt.setX(-1.0)
            pt.setY(-1.0)
            return pt

        pt.setY((b2 - b1) / (slope1 - slope2))
        pt.setX(slope1 * pt.y() + b1)
        # the intersection point must be inside the segment (x1, y1) (x2, y2)
        if x2 >= x1 and y2 >= y1:
            if not ((x1 <= pt.y() and pt.y() <= x2) and
                    (y1 <= pt.x() and pt.x() <= y2)):
                pt.setX(-1.0)
                pt.setY(-1.0)
        elif x2 < x1 and y2 >= y1:
            if not ((x2 <= pt.y() and pt.y() <= x1) and
                    (y1 <= pt.x() and pt.x() <= y2)):
                pt.setX(-1.0)
                pt.setY(-1.0)
        elif x2 >= x1 and y2 < y1:
            if not ((x1 <= pt.y() and pt.y() <= x2) and
                    (y2 <= pt.x() and pt.x() <= y1)):
                pt.setX(-1.0)
                pt.setY(-1.0)
        else:
            if not ((x2 <= pt.y() and pt.y() <= x1) and
                    (y2 <= pt.x() and pt.x() <= y1)):
                pt.setX(-1.0)
                pt.setY(-1.0)

        return pt

    def widgetMoved(self):
        """
        Public method to recalculate the association after a widget was moved.
        """
        self.calculateEndingPoints()

    def unassociate(self):
        """
        Public method to unassociate from the widgets.
        """
        self.itemA.removeAssociation(self)
        self.itemB.removeAssociation(self)

    def buildAssociationItemDataString(self):
        """
        Public method to build a string to persist the specific item data.

        This string should be built like "attribute=value" with pairs separated
        by ", ". value must not contain ", " or newlines.

        @return persistence data
        @rtype str
        """
        entries = [
            "src={0}".format(self.itemA.getId()),
            "dst={0}".format(self.itemB.getId()),
            "type={0}".format(self.assocType.value),
            "topToBottom={0}".format(self.topToBottom)
        ]
        return ", ".join(entries)

    @classmethod
    def parseAssociationItemDataString(cls, data):
        """
        Class method to parse the given persistence data.

        Entries without "=" and unknown keys are ignored; missing entries
        keep their defaults (-1 for the IDs, NORMAL, False).

        @param data persisted data to be parsed
        @type str
        @return tuple with the IDs of the source and destination items,
            the association type and a flag indicating to associate from top
            to bottom
        @rtype tuple of (int, int, AssociationType, bool)
        """
        src = -1
        dst = -1
        assocType = AssociationType.NORMAL
        topToBottom = False
        for entry in data.split(", "):
            if "=" in entry:
                key, value = entry.split("=", 1)
                if key == "src":
                    src = int(value)
                elif key == "dst":
                    dst = int(value)
                elif key == "type":
                    assocType = AssociationType(int(value))
                elif key == "topToBottom":
                    topToBottom = Utilities.toBool(value)

        return src, dst, assocType, topToBottom

    def toDict(self):
        """
        Public method to collect data to be persisted.

        @return dictionary containing data to be persisted
        @rtype dict
        """
        return {
            "src": self.itemA.getId(),
            "dst": self.itemB.getId(),
            "type": self.assocType.value,
            "topToBottom": self.topToBottom,
        }

    @classmethod
    def fromDict(cls, data, umlItems, colors=None):
        """
        Class method to create an association item from persisted data.

        @param data dictionary containing the persisted data as generated
            by toDict()
        @type dict
        @param umlItems UML items indexed by their item ID (as returned by
            getId())
        @type dict or list of UMLItem
        @param colors tuple containing the foreground and background colors
        @type tuple of (QColor, QColor)
        @return created association item or None, if the data is incomplete
            or refers to unknown items
        @rtype AssociationItem
        """
        try:
            return cls(umlItems[data["src"]],
                       umlItems[data["dst"]],
                       assocType=AssociationType(data["type"]),
                       topToBottom=data["topToBottom"],
                       colors=colors)
        except (KeyError, IndexError, ValueError):
            # IndexError covers unknown IDs when umlItems is a list
            return None
643