1# -*- coding: utf-8 -*-
2
3# Copyright (c) 2007 - 2021 Detlev Offenbach <detlev@die-offenbachs.de>
4#
5
6"""
Module implementing a builder for UML like class diagrams.
8"""
9
10from itertools import zip_longest
11import os
12
13from PyQt5.QtWidgets import QGraphicsTextItem
14
15import Utilities
16import Preferences
17
18from .UMLDiagramBuilder import UMLDiagramBuilder
19
20
class UMLClassDiagramBuilder(UMLDiagramBuilder):
    """
    Class implementing a builder for UML like class diagrams.

    The builder parses a single source file, creates one shape per class
    (or Ruby module) found in the class hierarchy of that file, arranges
    the shapes generation by generation and connects them with
    generalisation associations.
    """
    def __init__(self, dialog, view, project, file, noAttrs=False):
        """
        Constructor

        @param dialog reference to the UML dialog
        @type UMLDialog
        @param view reference to the view object
        @type UMLGraphicsView
        @param project reference to the project object
        @type Project
        @param file file name of a python module to be shown
        @type str
        @param noAttrs flag indicating, that no attributes should be shown
        @type bool
        """
        super().__init__(dialog, view, project)
        self.setObjectName("UMLClassDiagramBuilder")

        self.file = file
        self.noAttrs = noAttrs

        self.__relFile = self.__projectRelativeFile()

    def __projectRelativeFile(self):
        """
        Private method to determine the project relative name of the
        current file.

        @return file name relative to the project, if the project reports
            the file as one of its sources; an empty string otherwise
        @rtype str
        """
        if self.project.isProjectSource(self.file):
            return self.project.getRelativePath(self.file)

        return ""

    def __showErrorMessage(self, message):
        """
        Private method to add a text item showing an error message to the
        scene.

        @param message message to be shown (may contain HTML markup)
        @type str
        """
        ct = QGraphicsTextItem(None)
        ct.setHtml(self.buildErrorMessage(message))
        self.scene.addItem(ct)

    def initialize(self):
        """
        Public method to initialize the object.

        It derives the diagram name from the project name and the file
        name and sets it on the view.
        """
        pname = self.project.getProjectName()
        name = (
            self.tr("Class Diagram {0}: {1}").format(
                pname, self.project.getRelativePath(self.file))
            if pname and self.project.isProjectSource(self.file) else
            self.tr("Class Diagram: {0}").format(self.file)
        )
        self.umlView.setDiagramName(name)

    def __getCurrentShape(self, name):
        """
        Private method to get the named shape.

        @param name name of the shape
        @type str
        @return shape (None, if no shape is registered under this name)
        @rtype QGraphicsItem
        """
        return self.allClasses.get(name)

    def buildDiagram(self):
        """
        Public method to build the class shapes of the class diagram.

        The algorithm is borrowed from Boa Constructor.
        """
        import Utilities.ModuleParser

        self.allClasses = {}
        self.allModules = {}

        try:
            extensions = (
                Preferences.getPython("Python3Extensions") +
                ['.rb']
            )
            module = Utilities.ModuleParser.readModule(
                self.file, extensions=extensions, caching=False)
        except ImportError:
            self.__showErrorMessage(
                self.tr("The module <b>'{0}'</b> could not be found.")
                    .format(self.file)
            )
            return

        if self.file not in self.allModules:
            self.allModules[self.file] = []

        routes = []
        nodes = []
        # breadth first traversal of the (nested) class hierarchy
        todo = [module.createHierarchy()]
        classesFound = False
        while todo:
            hierarchy = todo.pop(0)
            for className in hierarchy:
                classesFound = True
                cw = self.__getCurrentShape(className)
                if not cw and '.' in className:
                    # retry with the unqualified class name and register
                    # the shape under the qualified name as well
                    cw = self.__getCurrentShape(className.split('.')[-1])
                    if cw:
                        self.allClasses[className] = cw
                        if className not in self.allModules[self.file]:
                            self.allModules[self.file].append(className)
                if cw and cw.noAttrs != self.noAttrs:
                    # shape was built with a different attribute setting
                    cw = None
                if cw and not (cw.external and
                               (className in module.classes or
                                className in module.modules)):
                    # reuse the existing shape
                    if cw.scene() != self.scene:
                        self.scene.addItem(cw)
                        cw.setPos(10, 10)
                        if className not in nodes:
                            nodes.append(className)
                else:
                    if className in module.classes:
                        # this is a local class (defined in this module)
                        self.__addLocalClass(
                            className, module.classes[className], 0, 0)
                    elif className in module.modules:
                        # this is a local module (defined in this module)
                        self.__addLocalClass(
                            className, module.modules[className], 0, 0, True)
                    else:
                        self.__addExternalClass(className, 0, 0)
                    nodes.append(className)

                subHierarchy = hierarchy.get(className)
                if subHierarchy:
                    todo.append(subHierarchy)
                    for child in subHierarchy:
                        if (className, child) not in routes:
                            routes.append((className, child))

        if classesFound:
            # pass a copy of the routes because they are needed again
            # for the associations
            self.__arrangeClasses(nodes, routes[:])
            self.__createAssociations(routes)
            self.umlView.autoAdjustSceneSize(limit=True)
        else:
            self.__showErrorMessage(
                self.tr("The module <b>'{0}'</b> does not contain any"
                        " classes.").format(self.file)
            )

    def __arrangeClasses(self, nodes, routes, whiteSpaceFactor=1.2):
        """
        Private method to arrange the shapes on the canvas.

        The algorithm is borrowed from Boa Constructor.

        @param nodes list of nodes to arrange
        @type list of str
        @param routes list of routes
        @type list of tuple of (str, str)
        @param whiteSpaceFactor factor to increase whitespace between
            items
        @type float
        """
        from . import GraphicsUtilities
        generations = GraphicsUtilities.sort(nodes, routes)

        # calculate width and height of all elements
        sizes = [
            [self.__getCurrentShape(child).sceneBoundingRect()
             for child in generation]
            for generation in generations
        ]

        # calculate total width and total height
        width = 0
        height = 0
        widths = []
        heights = []
        for generation in sizes:
            currentWidth = sum(rect.right() for rect in generation)
            currentHeight = max([0] + [rect.bottom() for rect in generation])

            # update totals
            width = max(width, currentWidth)
            height += currentHeight

            # store generation info
            widths.append(currentWidth)
            heights.append(currentHeight)

        # add in some whitespace
        width *= whiteSpaceFactor
        height = height * whiteSpaceFactor - 20
        verticalWhiteSpace = 40.0

        # never shrink the scene below its current size
        sceneRect = self.umlView.sceneRect()
        width += 50.0
        height += 50.0
        swidth = max(sceneRect.width(), width)
        sheight = max(sceneRect.height(), height)
        self.umlView.setSceneSize(swidth, sheight)

        # distribute each generation across the width and the
        # generations across height
        y = 10.0
        for currentWidth, currentHeight, generation in (
                zip_longest(widths, heights, generations)
        ):
            x = 10.0
            # whiteSpace is the space between any two elements; a
            # generation with a single element would yield a divisor of
            # zero, so 2.0 is used in that case
            whiteSpace = (
                (width - currentWidth - 20) /
                (len(generation) - 1.0 or 2.0)
            )
            for className in generation:
                cw = self.__getCurrentShape(className)
                cw.setPos(x, y)
                rect = cw.sceneBoundingRect()
                x = x + rect.width() + whiteSpace
            y = y + currentHeight + verticalWhiteSpace

    def __addLocalClass(self, className, _class, x, y, isRbModule=False):
        """
        Private method to add a class defined in the module.

        @param className name of the class to be used as a dictionary key
        @type str
        @param _class class to be shown
        @type ModuleParser.Class
        @param x x-coordinate
        @type float
        @param y y-coordinate
        @type float
        @param isRbModule flag indicating a Ruby module
        @type bool
        """
        from .ClassItem import ClassItem, ClassModel
        name = _class.name
        if isRbModule:
            name = "{0} (Module)".format(name)
        cl = ClassModel(
            name,
            sorted(_class.methods.keys()),
            sorted(_class.attributes.keys()),
            sorted(_class.globals.keys())
        )
        cw = ClassItem(cl, False, x, y, noAttrs=self.noAttrs, scene=self.scene,
                       colors=self.umlView.getDrawingColors())
        cw.setId(self.umlView.getItemId())
        self.allClasses[className] = cw
        if _class.name not in self.allModules[self.file]:
            self.allModules[self.file].append(_class.name)

    def __addExternalClass(self, _class, x, y):
        """
        Private method to add a class defined outside the module.

        The shape is created as an external class item showing just the
        class name.

        @param _class name of the class to be shown
        @type str
        @param x x-coordinate
        @type float
        @param y y-coordinate
        @type float
        """
        from .ClassItem import ClassItem, ClassModel
        cl = ClassModel(_class)
        cw = ClassItem(cl, True, x, y, noAttrs=self.noAttrs, scene=self.scene,
                       colors=self.umlView.getDrawingColors())
        cw.setId(self.umlView.getItemId())
        self.allClasses[_class] = cw
        if _class not in self.allModules[self.file]:
            self.allModules[self.file].append(_class)

    def __createAssociations(self, routes):
        """
        Private method to generate the associations between the class shapes.

        @param routes list of relationships given as tuples of
            (parent class name, child class name)
        @type list of tuple of (str, str)
        """
        from .AssociationItem import AssociationItem, AssociationType
        for route in routes:
            if len(route) > 1:
                assoc = AssociationItem(
                    self.__getCurrentShape(route[1]),
                    self.__getCurrentShape(route[0]),
                    AssociationType.GENERALISATION,
                    topToBottom=True,
                    colors=self.umlView.getDrawingColors())
                self.scene.addItem(assoc)

    def getPersistenceData(self):
        """
        Public method to get a string for data to be persisted.

        @return persisted data string
        @rtype str
        """
        return "file={0}, no_attributes={1}".format(self.file, self.noAttrs)

    def parsePersistenceData(self, version, data):
        """
        Public method to parse persisted data.

        The data is expected in the format written by
        'getPersistenceData()', i.e. "file=<name>, no_attributes=<flag>".

        @param version version of the data
        @type str
        @param data persisted data to be parsed
        @type str
        @return flag indicating success
        @rtype bool
        """
        # Split at the last separator only; the file name itself may
        # contain ", " while the flag value never does.
        parts = data.rsplit(", ", 1)
        if (
            len(parts) != 2 or
            not parts[0].startswith("file=") or
            not parts[1].startswith("no_attributes=")
        ):
            return False

        self.file = parts[0].split("=", 1)[1].strip()
        self.noAttrs = Utilities.toBool(parts[1].split("=", 1)[1].strip())
        # keep the project relative name in sync with the new file, so
        # that 'toDict()' does not persist a stale file name
        self.__relFile = self.__projectRelativeFile()

        self.initialize()

        return True

    def toDict(self):
        """
        Public method to collect data to be persisted.

        The file name is stored relative to the project, if a project
        relative name is known, and as given otherwise.

        @return dictionary containing data to be persisted
        @rtype dict
        """
        data = {
            "project_name": self.project.getProjectName(),
            "no_attributes": self.noAttrs,
        }

        data["file"] = (
            Utilities.fromNativeSeparators(self.__relFile)
            if self.__relFile else
            Utilities.fromNativeSeparators(self.file)
        )

        return data

    def fromDict(self, version, data):
        """
        Public method to populate the class with data persisted by 'toDict()'.

        @param version version of the data
        @type str
        @param data dictionary containing the persisted data
        @type dict
        @return tuple containing a flag indicating success and an info
            message in case the diagram belongs to a different project
        @rtype tuple of (bool, str)
        """
        try:
            self.noAttrs = data["no_attributes"]

            file = Utilities.toNativeSeparators(data["file"])
            if os.path.isabs(file):
                self.file = file
                self.__relFile = ""
            else:
                # relative file paths indicate a project file
                if data["project_name"] != self.project.getProjectName():
                    msg = self.tr(
                        "<p>The diagram belongs to project <b>{0}</b>."
                        " Please open it and try again.</p>"
                    ).format(data["project_name"])
                    return False, msg

                self.__relFile = file
                self.file = self.project.getAbsolutePath(file)
        except KeyError:
            # a required key is missing in the persisted data
            return False, ""

        self.initialize()

        return True, ""
408