1# -*- coding: utf-8 -*-
2"""
3modelviz.py - DOT file generator for Django Models
4
5Based on:
6  Django model to DOT (Graphviz) converter
7  by Antonio Cavedoni <antonio@cavedoni.org>
8  Adapted to be used with django-extensions
9"""
10
11import datetime
12import os
13import re
14
15from django.apps import apps
16from django.db.models.fields.related import (
17    ForeignKey, ManyToManyField, OneToOneField, RelatedField,
18)
19from django.contrib.contenttypes.fields import GenericRelation
20from django.template import Context, Template, loader
21from django.utils.encoding import force_str
22from django.utils.safestring import mark_safe
23from django.utils.translation import activate as activate_language
24
25
__version__ = "1.1"
__license__ = "Python"
# A plain string: a trailing comma here would silently make this a 1-tuple.
__author__ = "Bas van Oostveen <v.oostveen@gmail.com>"
# One entry per contributor. Every line needs its comma, otherwise implicit
# string concatenation merges neighbouring entries into a single string.
__contributors__ = [
    "Antonio Cavedoni <http://cavedoni.com/>",
    "Stefano J. Attardi <http://attardi.org/>",
    "limodou <http://www.donews.net/limodou/>",
    "Carlo C8E Miron",
    "Andre Campos <cahenan@gmail.com>",
    "Justin Findlay <jfindlay@gmail.com>",
    "Alexander Houben <alexander@houben.ch>",
    "Joern Hees <gitdev@joernhees.de>",
    "Kevin Cherepski <cherepski@gmail.com>",
    "Jose Tomas Tocino <theom3ga@gmail.com>",
    "Adam Dobrawy <naczelnik@jawnosc.tk>",
    "Mikkel Munch Mortensen <https://www.detfalskested.dk/>",
    "Andrzej Bistram <andrzej.bistram@gmail.com>",
    "Daniel Lipsitt <danlipsitt@gmail.com>",
]
45
46
def parse_file_or_list(arg):
    """
    Normalize an option given as a sequence, a file path, or a comma-separated string.

    - falsy ``arg`` -> ``[]``
    - list / tuple / set -> returned unchanged
    - a string containing no comma that names an existing file -> one
      whitespace-stripped entry per line of that file
    - any other string -> whitespace-stripped comma-separated entries
    """
    if not arg:
        return []
    if isinstance(arg, (list, tuple, set)):
        return arg
    if ',' not in arg and os.path.isfile(arg):
        # Use a context manager so the file handle is closed deterministically.
        with open(arg) as fp:
            return [line.strip() for line in fp]
    return [entry.strip() for entry in arg.split(',')]
55
56
class ModelGraph:
    """
    Collect model, field and relation data for a set of Django apps.

    The collected per-app contexts (``self.graphs``) are exposed through
    ``get_graph_data`` for rendering with a DOT template or serializing to JSON.
    """

    def __init__(self, app_labels, **kwargs):
        """
        :param app_labels: labels of the apps to graph; ignored when
            ``all_applications`` is true (every installed app is used instead).

        Recognized keyword options: ``cli_options``, ``disable_fields``,
        ``disable_abstract_fields``, ``include_models``, ``all_applications``,
        ``group_models``, ``verbose_names``, ``inheritance``,
        ``relations_as_fields``, ``sort_fields``, ``language``,
        ``exclude_columns``, ``exclude_models``, ``hide_edge_labels`` and
        ``arrow_shape``. ``include_models``, ``exclude_columns`` and
        ``exclude_models`` accept anything ``parse_file_or_list`` accepts.
        """
        self.graphs = []
        self.cli_options = kwargs.get('cli_options', None)
        self.disable_fields = kwargs.get('disable_fields', False)
        self.disable_abstract_fields = kwargs.get('disable_abstract_fields', False)
        self.include_models = parse_file_or_list(
            kwargs.get('include_models', "")
        )
        self.all_applications = kwargs.get('all_applications', False)
        self.use_subgraph = kwargs.get('group_models', False)
        self.verbose_names = kwargs.get('verbose_names', False)
        self.inheritance = kwargs.get('inheritance', True)
        self.relations_as_fields = kwargs.get("relations_as_fields", True)
        self.sort_fields = kwargs.get("sort_fields", True)
        self.language = kwargs.get('language', None)
        if self.language is not None:
            # Activate the translation so verbose names render in that language.
            activate_language(self.language)
        self.exclude_columns = parse_file_or_list(
            kwargs.get('exclude_columns', "")
        )
        self.exclude_models = parse_file_or_list(
            kwargs.get('exclude_models', "")
        )
        self.hide_edge_labels = kwargs.get('hide_edge_labels', False)
        # Graphviz arrow shape substituted into FK / M2M edge attribute lists.
        self.arrow_shape = kwargs.get("arrow_shape")
        if self.all_applications:
            self.app_labels = [app.label for app in apps.get_app_configs()]
        else:
            self.app_labels = app_labels

    def generate_graph_data(self):
        """
        Populate ``self.graphs`` and flag relations whose targets are drawn already.

        A relation keeps ``needs_node=True`` only when its target model is not
        itself one of the graphed model nodes.
        """
        self.process_apps()

        # Set for O(1) membership tests in the loop below.
        nodes = {
            model['name'] for graph in self.graphs for model in graph['models']
        }

        for graph in self.graphs:
            for model in graph['models']:
                for relation in model['relations']:
                    if relation is not None and relation['target'] in nodes:
                        relation['needs_node'] = False

    def get_graph_data(self, as_json=False):
        """
        Return the template/JSON payload describing the collected graphs.

        :param as_json: when true, ``graphs`` holds flattened, JSON-friendly
            copies of the app contexts with the model class (``'model'``) and
            field objects (``'field'``) removed. ``self.graphs`` itself is left
            untouched, so this can be called repeatedly.
        """
        now = datetime.datetime.now()
        graph_data = {
            'created_at': now.strftime("%Y-%m-%d %H:%M"),
            'cli_options': self.cli_options,
            'disable_fields': self.disable_fields,
            'disable_abstract_fields': self.disable_abstract_fields,
            'use_subgraph': self.use_subgraph,
        }

        if as_json:
            graph_data['graphs'] = [self._json_graph(context) for context in self.graphs]
        else:
            graph_data['graphs'] = self.graphs

        return graph_data

    @staticmethod
    def _json_graph(context):
        """Return a flattened copy of one app context without model/field objects."""
        graph = context.flatten()
        # flatten() is shallow: the 'models' list and its dicts are shared with
        # self.graphs, so build fresh dicts rather than popping keys in place.
        models = []
        for model_data in graph['models']:
            model_copy = {key: value for key, value in model_data.items() if key != 'model'}
            model_copy['fields'] = [
                {key: value for key, value in field_data.items() if key != 'field'}
                for field_data in model_data['fields']
            ]
            models.append(model_copy)
        graph['models'] = models
        return graph

    def _field_label(self, field):
        """Return ``field``'s display label: its verbose name (capitalized if all lowercase) or its name."""
        if self.verbose_names and field.verbose_name:
            label = force_str(field.verbose_name)
            if label.islower():
                label = label.capitalize()
            return label
        return field.name

    def add_attributes(self, field, abstract_fields):
        """
        Build the template context describing one model field.

        :param abstract_fields: fields inherited from abstract bases; used to
            set the ``abstract`` flag.
        """
        label = self._field_label(field)

        t = type(field).__name__
        if isinstance(field, (OneToOneField, ForeignKey)):
            # Show which column on the target the FK points at.
            t += " ({0})".format(field.remote_field.field_name)
        # TODO: ManyToManyField, GenericRelation

        return {
            'field': field,
            'name': field.name,
            'label': label,
            'type': t,
            'blank': field.blank,
            'abstract': field in abstract_fields,
            'relation': isinstance(field, RelatedField),
            'primary_key': field.primary_key,
        }

    def _resolve_target_model(self, field):
        """Return the model class a relational field points at, resolving 'self' and lazy string references."""
        remote = field.remote_field.model
        if not isinstance(remote, str):
            return remote
        if remote == 'self':
            return field.model
        if '.' in remote:
            app_label, model_name = remote.split('.', 1)
        else:
            # Unqualified name: same app as the declaring model.
            app_label = field.model._meta.app_label
            model_name = remote
        return apps.get_model(app_label, model_name)

    def add_relation(self, field, model, extras=""):
        """
        Build the edge context for relational ``field`` declared on ``model``.

        :param model: the model context dict being built (its ``relations``
            are checked for duplicates).
        :param extras: DOT attribute list controlling the edge's arrows.
        :returns: the relation dict, or ``None`` when an identical relation is
            already present or the target model is filtered out by ``use_model``.
        """
        label = self._field_label(field)

        # show related field name
        if hasattr(field, 'related_query_name'):
            related_query_name = field.related_query_name()
            if self.verbose_names and related_query_name.islower():
                related_query_name = related_query_name.replace('_', ' ').capitalize()
            label = '{} ({})'.format(label, force_str(related_query_name))
        if self.hide_edge_labels:
            label = ''

        target_model = self._resolve_target_model(field)
        _rel = self.get_relation_context(target_model, field, label, extras)

        if _rel not in model['relations'] and self.use_model(_rel['target']):
            return _rel
        return None

    def get_abstract_models(self, appmodels):
        """
        Return the abstract model classes that are direct bases of ``appmodels``.

        Duplicates are removed while keeping first-seen order. A ``set`` would
        order classes by their identity-based hash, making the generated graph
        differ from run to run.
        """
        abstract_models = dict.fromkeys(
            base
            for appmodel in appmodels
            for base in appmodel.__bases__
            if hasattr(base, '_meta') and base._meta.abstract
        )
        return list(abstract_models)

    def get_app_context(self, app):
        """Return a fresh template ``Context`` for one app's subgraph."""
        return Context({
            'name': '"%s"' % app.name,
            'app_name': "%s" % app.name,
            'cluster_app_name': "cluster_%s" % app.name.replace(".", "_"),
            'models': []
        })

    def get_appmodel_attributes(self, appmodel):
        """
        Return the local fields of ``appmodel`` to list as attributes.

        When ``relations_as_fields`` is false, relational fields are left out
        because they are drawn as graph edges instead.
        """
        if self.relations_as_fields:
            return list(appmodel._meta.local_fields)
        return [
            field for field in appmodel._meta.local_fields
            if not isinstance(field, RelatedField)
        ]

    def get_appmodel_abstracts(self, appmodel):
        """Return the names of ``appmodel``'s direct abstract model bases."""
        return [
            abstract_model.__name__ for abstract_model in appmodel.__bases__
            if hasattr(abstract_model, '_meta') and abstract_model._meta.abstract
        ]

    def get_appmodel_context(self, appmodel, appmodel_abstracts):
        """Return the (initially field- and relation-less) context dict for one model."""
        context = {
            'model': appmodel,
            'app_name': appmodel.__module__.replace(".", "_"),
            'name': appmodel.__name__,
            'abstracts': appmodel_abstracts,
            'fields': [],
            'relations': []
        }

        if self.verbose_names and appmodel._meta.verbose_name:
            context['label'] = force_str(appmodel._meta.verbose_name)
        else:
            context['label'] = context['name']

        return context

    def get_bases_abstract_fields(self, c):
        """Return all fields declared on ``c``'s abstract bases, recursing through their abstract bases."""
        _abstract_fields = []
        for e in c.__bases__:
            if hasattr(e, '_meta') and e._meta.abstract:
                _abstract_fields.extend(e._meta.fields)
                _abstract_fields.extend(self.get_bases_abstract_fields(e))
        return _abstract_fields

    def get_inheritance_context(self, appmodel, parent):
        """Return the edge context for ``appmodel`` inheriting from model ``parent``."""
        label = "multi-table"
        if parent._meta.abstract:
            label = "abstract"
        if appmodel._meta.proxy:
            label = "proxy"
        # Literal backslash-n: a line break inside the DOT label.
        label += r"\ninheritance"
        if self.hide_edge_labels:
            label = ''
        return {
            'target_app': parent.__module__.replace(".", "_"),
            'target': parent.__name__,
            'type': "inheritance",
            'name': "inheritance",
            'label': label,
            'arrows': '[arrowhead=empty, arrowtail=none, dir=both]',
            'needs_node': True,
        }

    def get_models(self, app):
        """Return ``app.get_models()`` as a list."""
        return list(app.get_models())

    def get_relation_context(self, target_model, field, label, extras):
        """Return the edge context dict for ``field`` pointing at ``target_model``."""
        return {
            'target_app': target_model.__module__.replace('.', '_'),
            'target': target_model.__name__,
            'type': type(field).__name__,
            'name': field.name,
            'label': label,
            'arrows': extras,
            'needs_node': True
        }

    def process_attributes(self, field, model, pk, abstract_fields):
        """
        Append ``field``'s attribute context to ``model['fields']``.

        Excluded columns and the primary key (already listed first by
        ``process_apps``) are skipped. Returns a shallow copy of ``model``,
        whose ``fields`` list is the same object as ``model['fields']``.
        """
        newmodel = model.copy()
        if self.skip_field(field) or (pk and field == pk):
            return newmodel
        newmodel['fields'].append(self.add_attributes(field, abstract_fields))
        return newmodel

    def process_apps(self):
        """Build one context per app in ``self.app_labels`` and store non-empty ones in ``self.graphs``."""
        for app_label in self.app_labels:
            app = apps.get_app_config(app_label)
            if not app:
                continue
            app_graph = self.get_app_context(app)
            app_models = self.get_models(app)
            # Abstract bases are not returned by get_models(); graph them too.
            abstract_models = self.get_abstract_models(app_models)
            app_models = abstract_models + app_models

            for appmodel in app_models:
                if not self.use_model(appmodel._meta.object_name):
                    continue
                appmodel_abstracts = self.get_appmodel_abstracts(appmodel)
                abstract_fields = self.get_bases_abstract_fields(appmodel)
                model = self.get_appmodel_context(appmodel, appmodel_abstracts)
                attributes = self.get_appmodel_attributes(appmodel)

                # find primary key and print it first, ignoring implicit id if other pk exists
                pk = appmodel._meta.pk
                if pk and not appmodel._meta.abstract and pk in attributes:
                    model['fields'].append(self.add_attributes(pk, abstract_fields))

                for field in attributes:
                    model = self.process_attributes(field, model, pk, abstract_fields)

                if self.sort_fields:
                    model = self.sort_model_fields(model)

                for field in appmodel._meta.local_fields:
                    model = self.process_local_fields(field, model, abstract_fields)

                for field in appmodel._meta.local_many_to_many:
                    model = self.process_local_many_to_many(field, model)

                if self.inheritance:
                    # add inheritance arrows
                    for parent in appmodel.__bases__:
                        model = self.process_parent(parent, appmodel, model)

                app_graph['models'].append(model)
            if app_graph['models']:
                self.graphs.append(app_graph)

    def process_local_fields(self, field, model, abstract_fields):
        """Add an edge for a local ``OneToOneField`` / ``ForeignKey``; other fields add nothing."""
        newmodel = model.copy()
        if field.attname.endswith('_ptr_id') or field in abstract_fields or self.skip_field(field):
            # excluding field redundant with inheritance relation
            # excluding fields inherited from abstract classes. they too show as local_fields
            return newmodel
        if isinstance(field, OneToOneField):
            relation = self.add_relation(
                field, newmodel, '[arrowhead=none, arrowtail=none, dir=both]'
            )
        elif isinstance(field, ForeignKey):
            relation = self.add_relation(
                field,
                newmodel,
                '[arrowhead=none, arrowtail={}, dir=both]'.format(
                    self.arrow_shape
                ),
            )
        else:
            relation = None
        if relation is not None:
            newmodel['relations'].append(relation)
        return newmodel

    def process_local_many_to_many(self, field, model):
        """
        Add an edge for a local ``ManyToManyField`` or ``GenericRelation``.

        M2M fields using an explicit ``through`` model get no edge here; only
        auto-created through tables are drawn as a direct edge.
        """
        newmodel = model.copy()
        if self.skip_field(field):
            return newmodel
        relation = None
        if isinstance(field, ManyToManyField):
            if hasattr(field.remote_field.through, '_meta') and field.remote_field.through._meta.auto_created:
                relation = self.add_relation(
                    field,
                    newmodel,
                    # Comma-separated like every other edge attribute list.
                    '[arrowhead={}, arrowtail={}, dir=both]'.format(
                        self.arrow_shape, self.arrow_shape
                    ),
                )
        elif isinstance(field, GenericRelation):
            relation = self.add_relation(field, newmodel, mark_safe('[style="dotted", arrowhead=normal, arrowtail=normal, dir=both]'))
        if relation is not None:
            newmodel['relations'].append(relation)
        return newmodel

    def process_parent(self, parent, appmodel, model):
        """Add an inheritance edge from ``appmodel`` to ``parent`` when ``parent`` is a model."""
        newmodel = model.copy()
        if hasattr(parent, "_meta"):  # parent is a model
            _rel = self.get_inheritance_context(appmodel, parent)
            # TODO: seems as if abstract models aren't part of models.getModels, which is why they are printed by this without any attributes.
            if _rel not in newmodel['relations'] and self.use_model(_rel['target']):
                newmodel['relations'].append(_rel)
        return newmodel

    def sort_model_fields(self, model):
        """Return a copy of ``model`` with fields ordered: primary keys, then relations, then by label."""
        newmodel = model.copy()
        newmodel['fields'] = sorted(
            newmodel['fields'],
            key=lambda field: (not field['primary_key'], not field['relation'], field['label']),
        )
        return newmodel

    @staticmethod
    def _matches_any(patterns, model_name):
        """True if ``model_name`` fully matches any pattern (``*`` is a wildcard)."""
        return any(
            re.search('^%s$' % pattern.replace('*', '.*'), model_name)
            for pattern in patterns
        )

    def use_model(self, model_name):
        """
        Decide whether to use a model, based on the model name and the lists of
        models to exclude and include.

        A match in the include list wins over the exclude list. When an include
        list is given, unmatched models are rejected; otherwise they are used.
        """
        if self.include_models and self._matches_any(self.include_models, model_name):
            return True
        if self.exclude_models and self._matches_any(self.exclude_models, model_name):
            return False
        # Return `True` if `include_models` is falsey, otherwise return `False`.
        return not self.include_models

    def skip_field(self, field):
        """True when ``field``'s name (or verbose name, if enabled) is in ``exclude_columns``."""
        if self.exclude_columns:
            if self.verbose_names and field.verbose_name:
                if field.verbose_name in self.exclude_columns:
                    return True
            if field.name in self.exclude_columns:
                return True
        return False
415
416
def generate_dot(graph_data, template='django_extensions/graph_models/digraph.dot'):
    """
    Render ``graph_data`` into DOT source with a Django template.

    ``template`` is either a template name, which is loaded through
    ``loader.get_template``, or an already-loaded template object.
    A plain ``Exception`` is raised when the template is neither a Django
    ``Template`` nor a wrapper exposing one as ``.template``.
    """
    if isinstance(template, str):
        template = loader.get_template(template)

    uses_django_engine = isinstance(template, Template) or (
        hasattr(template, 'template') and isinstance(template.template, Template)
    )
    if not uses_django_engine:
        raise Exception("Default Django template loader isn't used. "
                        "This can lead to the incorrect template rendering. "
                        "Please, check the settings.")

    render_context = Context(graph_data).flatten()
    return template.render(render_context)
430
431
def generate_graph_data(*args, **kwargs):
    """Construct a ``ModelGraph`` from the given arguments, run it, and return its graph data."""
    model_graph = ModelGraph(*args, **kwargs)
    model_graph.generate_graph_data()
    graph_data = model_graph.get_graph_data()
    return graph_data
436
437
def use_model(model, include_models, exclude_models):
    """Return whether ``model`` (a model name) passes the include/exclude filters."""
    filter_graph = ModelGraph(
        [],
        include_models=include_models,
        exclude_models=exclude_models,
    )
    return filter_graph.use_model(model)
441