1from transitions import Transition
2from transitions.extensions.markup import MarkupMachine
3from transitions.core import listify
4
5import logging
6from functools import partial
7import copy
8
# Module logger. A NullHandler is attached so this library stays silent unless the
# application configures logging itself (standard practice for library loggers).
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())

# Module-level alias of the builtin `super`; the classes in this module call `_super(...)`
# instead of `super(...)`.
# this is a workaround for dill issues when partials and super is used in conjunction
# without it, Python 3.0 - 3.3 will not support pickling
# https://github.com/pytransitions/transitions/issues/236
_super = super
16
17
class TransitionGraphSupport(Transition):
    """ Transition that keeps a model's (Nested)Graph in sync while the transition is conducted.

        Before the state change the model's graph styling is reset and the traversed edge is
        registered as the previous transition; after the state change every current state of the
        model is styled as 'active'.
    """

    def __init__(self, *args, **kwargs):
        """ Create the transition.

        Args:
            *args, **kwargs: Forwarded to Transition, except for the optional 'label' keyword.
            label (str, optional): If truthy, stored as `self.label`.
        """
        custom_label = kwargs.pop('label', None)
        _super(TransitionGraphSupport, self).__init__(*args, **kwargs)
        if not custom_label:
            return
        self.label = custom_label

    def _change_state(self, event_data):
        """ Perform the state change and update the model's graph styling around it. """
        machine = event_data.machine
        model = event_data.model

        graph_before = machine.model_graphs[model]
        graph_before.reset_styling()
        graph_before.set_previous_transition(self.source, self.dest, event_data.event.name)

        _super(TransitionGraphSupport, self)._change_state(event_data)  # pylint: disable=protected-access

        # look the graph up again: callbacks run during the state change may have replaced the
        # model's entry in `model_graphs` (e.g. by regenerating graphs with force_new)
        graph_after = machine.model_graphs[model]
        current_states = listify(getattr(model, machine.model_attribute))
        for state in _flatten(current_states):
            # objects exposing a `name` (e.g. enum members) are styled via this transition's dest
            node = self.dest if hasattr(state, 'name') else state
            graph_after.set_node_style(node, 'active')
37
38
class GraphMachine(MarkupMachine):
    """ Extends transitions.core.Machine with graph support.
        Is also used as a mixin for HierarchicalMachine.
        Attributes:
            _pickle_blacklist (list): Objects that should not/do not need to be pickled.
            transition_cls (cls): TransitionGraphSupport
            machine_attributes (dict): Graph attributes of the root graph. When a nested graph engine
                is selected, the instance gets its own copy extended with `hierarchical_machine_attributes`.
            hierarchical_machine_attributes (dict): Attributes merged into the instance's
                `machine_attributes` for nested (hierarchical) machines.
            style_attributes (dict): Styles keyed by element type ('node', 'edge', 'graph') and then by
                style name ('default', 'active', 'previous', ...).
    """

    _pickle_blacklist = ['model_graphs']
    transition_cls = TransitionGraphSupport

    machine_attributes = {
        'directed': 'true',
        'strict': 'false',
        'rankdir': 'LR',
    }

    hierarchical_machine_attributes = {
        'rankdir': 'TB',
        'rank': 'source',
        'nodesep': '1.5',
        'compound': 'true'
    }

    style_attributes = {
        'node': {
            '': {},
            'default': {
                'style': 'rounded, filled',
                'shape': 'rectangle',
                'fillcolor': 'white',
                'color': 'black',
                'peripheries': '1'
            },
            'inactive': {
                'fillcolor': 'white',
                'color': 'black',
                'peripheries': '1'
            },
            'parallel': {
                'shape': 'rectangle',
                'color': 'black',
                'fillcolor': 'white',
                'style': 'dashed, rounded, filled',
                'peripheries': '1'
            },
            'active': {
                'color': 'red',
                'fillcolor': 'darksalmon',
                'peripheries': '2'
            },
            'previous': {
                'color': 'blue',
                'fillcolor': 'azure2',
                'peripheries': '1'
            }
        },
        'edge': {
            '': {},
            'default': {
                'color': 'black'
            },
            'previous': {
                'color': 'blue'
            }
        },
        'graph': {
            '': {},
            'default': {
                'color': 'black',
                'fillcolor': 'white',
                'style': 'solid'
            },
            'previous': {
                'color': 'blue',
                'fillcolor': 'azure2',
                'style': 'filled'
            },
            'active': {
                'color': 'red',
                'fillcolor': 'darksalmon',
                'style': 'filled'
            },
            'parallel': {
                'color': 'black',
                'fillcolor': 'white',
                'style': 'dotted'
            }
        }
    }

    # model_graphs cannot be pickled. Omit them.
    def __getstate__(self):
        """ Return the instance dict without the attributes listed in `_pickle_blacklist`. """
        return {k: v for k, v in self.__dict__.items() if k not in self._pickle_blacklist}

    def __setstate__(self, state):
        """ Restore the pickled attributes and rebuild a fresh graph for every model. """
        self.__dict__.update(state)
        self.model_graphs = {}  # reinitialize new model_graphs
        for model in self.models:
            try:
                _ = self._get_graph(model, title=self.title)
            except AttributeError as e:
                _LOGGER.warning("Graph for model could not be initialized after pickling: %s", e)

    def __init__(self, *args, **kwargs):
        """ Create the machine. Graph-specific keywords are consumed here, the rest is forwarded.

        Keyword Args:
            title (str): Graph title (default 'State Machine').
            show_conditions (bool): Append transition conditions to edge labels.
            show_state_attributes (bool): Append tags, callbacks and timeouts to node labels.
            show_auto_transitions (bool): Legacy alias that enables 'auto_transitions_markup'.
            use_pygraphviz (bool): Prefer the pygraphviz engine if it can be imported (default True).
        """
        # remove graph config from keywords
        self.title = kwargs.pop('title', 'State Machine')
        self.show_conditions = kwargs.pop('show_conditions', False)
        self.show_state_attributes = kwargs.pop('show_state_attributes', False)
        # MarkupMachine receives 'auto_transitions_markup'; the older 'show_auto_transitions' keyword is
        # still accepted for backwards compatibility. Always pop the legacy keyword: previously it was only
        # popped when 'auto_transitions_markup' was falsy and otherwise leaked into the base constructor.
        show_auto_transitions = kwargs.pop('show_auto_transitions', False)
        kwargs['auto_transitions_markup'] = kwargs.get('auto_transitions_markup', False) or show_auto_transitions
        self.model_graphs = {}

        # determine graph engine; if pygraphviz cannot be imported, fall back to graphviz
        use_pygraphviz = kwargs.pop('use_pygraphviz', True)
        if use_pygraphviz:
            try:
                import pygraphviz  # noqa: F401  pylint: disable=unused-import
            except ImportError:
                use_pygraphviz = False
        self.graph_cls = self._init_graphviz_engine(use_pygraphviz)

        _LOGGER.debug("Using graph engine %s", self.graph_cls)
        _super(GraphMachine, self).__init__(*args, **kwargs)

        # for backwards compatibility assign get_combined_graph to get_graph
        # if model is not the machine
        if not hasattr(self, 'get_graph'):
            self.get_graph = self.get_combined_graph

    def _init_graphviz_engine(self, use_pygraphviz):
        """ Select the graph class used to render this machine.

        Args:
            use_pygraphviz (bool): Try the pygraphviz-based module first; fall back to the graphviz-based
                module if it cannot be imported.
        Returns: The (Nested)Graph class of the chosen backend.
        """
        # state class needs to have a separator and machine needs to be a context manager
        nested = hasattr(self.state_cls, 'separator') and hasattr(self, '__enter__')
        graph_cls = None
        if use_pygraphviz:
            try:
                if nested:
                    from .diagrams_pygraphviz import NestedGraph as graph_cls
                else:
                    from .diagrams_pygraphviz import Graph as graph_cls
            except ImportError:
                graph_cls = None
        if graph_cls is None:
            if nested:
                from .diagrams_graphviz import NestedGraph as graph_cls
            else:
                from .diagrams_graphviz import Graph as graph_cls
        if nested:
            # Build a per-instance dict instead of updating the class-level one in place; updating it
            # in place changed the layout of every GraphMachine created afterwards.
            self.machine_attributes = dict(self.machine_attributes, **self.hierarchical_machine_attributes)
        return graph_cls

    def _get_graph(self, model, title=None, force_new=False, show_roi=False):
        """ Return the graph of `model`, creating (and caching) it when necessary.

        Args:
            model: A model registered with this machine.
            title (str): Title of the resulting graph. A newly created graph falls back to the machine's title.
            force_new (bool): If set to True, (re-)generate the model's graph.
            show_roi (bool): If set to True, only render states that are active and/or can be reached from
                the current state.
        Returns: The result of the graph object's `get_graph`.
        """
        if force_new or model not in self.model_graphs:
            grph = self.graph_cls(self, title=title if title is not None else self.title)
            self.model_graphs[model] = grph
            try:
                for state in _flatten(listify(getattr(model, self.model_attribute))):
                    # objects with a `name` (e.g. enum members) are identified by that name; the machine
                    # has no `dest` attribute, which previously made this raise for every such state
                    grph.set_node_style(getattr(state, 'name', state), 'active')
            except AttributeError:
                _LOGGER.info("Could not set active state of diagram")
        graph = self.model_graphs[model]
        graph.roi_state = getattr(model, self.model_attribute) if show_roi else None
        return graph.get_graph(title=title)

    def get_combined_graph(self, title=None, force_new=False, show_roi=False):
        """ This method is currently equivalent to 'get_graph' of the first machine's model.
        In future releases of transitions, this function will return a combined graph with active states
        of all models.
        Args:
            title (str): Title of the resulting graph.
            force_new (bool): If set to True, (re-)generate the model's graph.
            show_roi (bool): If set to True, only render states that are active and/or can be reached from
                the current state.
        Returns: AGraph of the first machine's model.
        """
        _LOGGER.info('Returning graph of the first model. In future releases, this '
                     'method will return a combined graph of all models.')
        return self._get_graph(self.models[0], title, force_new, show_roi)

    def add_model(self, model, initial=None):
        """ Register model(s) and bind a `get_graph` partial to each of them.

        Raises:
            AttributeError: if a model already has a `get_graph` attribute.
        """
        models = listify(model)
        _super(GraphMachine, self).add_model(models, initial)
        for mod in models:
            mod = self if mod == 'self' else mod
            if hasattr(mod, 'get_graph'):
                raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
            setattr(mod, 'get_graph', partial(self._get_graph, mod))
            _ = mod.get_graph(title=self.title, force_new=True)  # initialises graph

    def add_states(self, states, on_enter=None, on_exit=None,
                   ignore_invalid_triggers=None, **kwargs):
        """ Calls the base method and regenerates the graphs of all models. """
        _super(GraphMachine, self).add_states(states, on_enter=on_enter, on_exit=on_exit,
                                              ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
        for model in self.models:
            model.get_graph(force_new=True)

    def add_transition(self, trigger, source, dest, conditions=None,
                       unless=None, before=None, after=None, prepare=None, **kwargs):
        """ Calls the base method and regenerates the graphs of all models. """
        _super(GraphMachine, self).add_transition(trigger, source, dest, conditions=conditions, unless=unless,
                                                  before=before, after=after, prepare=prepare, **kwargs)
        for model in self.models:
            model.get_graph(force_new=True)
248
249
class BaseGraph(object):
    """ Shared helpers for graph builders that render a machine's markup.

        Attributes:
            machine: The machine whose markup is rendered (reads `show_state_attributes`,
                `show_conditions`, `get_markup_config` and `state_cls` from it).
            fsm_graph: Initialised to None here; presumably populated by a subclass' `generate` — TODO confirm.
            roi_state: Region-of-interest state, or None to render everything.
    """

    def __init__(self, machine, title=None):
        self.machine = machine
        self.fsm_graph = None
        self.roi_state = None
        # `generate` is not defined in this class; subclasses are expected to provide it
        self.generate(title)

    def _convert_state_attributes(self, state):
        """ Build the label of a state from its markup dict.

        Uses `state['label']` if present, else `state['name']`. If the machine's
        `show_state_attributes` is set, tags, enter/exit callbacks and timeout information are appended.
        """
        label = state.get('label', state['name'])
        if self.machine.show_state_attributes:
            if 'tags' in state:
                label += ' [' + ', '.join(state['tags']) + ']'
            if 'on_enter' in state:
                label += r'\l- enter:\l  + ' + r'\l  + '.join(state['on_enter'])
            if 'on_exit' in state:
                label += r'\l- exit:\l  + ' + r'\l  + '.join(state['on_exit'])
            if 'timeout' in state:
                # str() tolerates numeric timeouts; a missing 'on_timeout' renders as '()' instead of a KeyError
                label += r'\l- timeout(' + str(state['timeout']) + 's)  -> (' + \
                    ', '.join(state.get('on_timeout', [])) + ')'
        return label

    def _transition_label(self, tran):
        """ Build the edge label of a transition markup dict.

        Uses `tran['label']` if present, else `tran['trigger']`. Transitions without 'dest' are
        marked ' [internal]'. If the machine's `show_conditions` is set, conditions and negated
        'unless' entries are appended as '[c1 & !u1]'.
        """
        edge_label = tran.get('label', tran['trigger'])
        if 'dest' not in tran:
            edge_label += " [internal]"
        if self.machine.show_conditions and any(prop in tran for prop in ['conditions', 'unless']):
            return '{edge_label} [{conditions}]'.format(
                edge_label=edge_label,
                conditions=' & '.join(tran.get('conditions', []) + ['!' + u for u in tran.get('unless', [])]),
            )
        return edge_label

    def _get_global_name(self, path):
        """ Enter each state of `path` via `self.machine(state)` and return `machine.get_global_name()`.

        Args:
            path (list or tuple): Sequence of state names. It is not modified (the previous
                implementation consumed it with `pop(0)`).
        """
        if not path:
            return self.machine.get_global_name()
        with self.machine(path[0]):
            return self._get_global_name(path[1:])

    def _get_elements(self):
        """ Collect the states and transitions to render from the machine's markup config.

        Nested scopes (states with 'children') are walked breadth-first. Transitions of nested scopes
        are shallow-copied and their 'source' (and 'dest', unless internal) are prefixed with the
        parent state names joined by `state_cls.separator`. For every state whose 'initial' is not a
        list, an edge from '<state>_anchor' to the initial child is added.

        Returns:
            tuple: (states, transitions). `states` holds only top-level state dicts; `transitions`
            holds all (possibly prefixed) transition dicts. On a missing markup key an error is logged
            and the elements collected so far are returned.
        """
        from collections import deque

        states = []
        transitions = []
        try:
            markup = self.machine.get_markup_config()
            queue = deque([([], markup)])

            while queue:
                prefix, scope = queue.popleft()
                for transition in scope.get('transitions', []):
                    if prefix:
                        t = copy.copy(transition)
                        t['source'] = self.machine.state_cls.separator.join(prefix + [t['source']])
                        if 'dest' in t:  # don't do this for internal transitions
                            t['dest'] = self.machine.state_cls.separator.join(prefix + [t['dest']])
                    else:
                        t = transition
                    transitions.append(t)
                for state in scope.get('children', []) + scope.get('states', []):
                    if not prefix:
                        states.append(state)

                    ini = state.get('initial', [])
                    if not isinstance(ini, list):
                        ini = ini.name if hasattr(ini, 'name') else ini
                        t = dict(trigger='',
                                 source=self.machine.state_cls.separator.join(prefix + [state['name']]) + '_anchor',
                                 dest=self.machine.state_cls.separator.join(prefix + [state['name'], ini]))
                        transitions.append(t)
                    if state.get('children', []):
                        queue.append((prefix + [state['name']], state))
        except KeyError as e:
            _LOGGER.error("Graph creation incomplete! Missing key: %s", e)
        return states, transitions
326
327
328def _flatten(item):
329    for elem in item:
330        if isinstance(elem, (list, tuple, set)):
331            for res in _flatten(elem):
332                yield res
333        else:
334            yield elem
335