from transitions import Transition
from transitions.extensions.markup import MarkupMachine
from transitions.core import listify

import logging
from functools import partial
from collections import deque
import copy

_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())

# this is a workaround for dill issues when partials and super is used in conjunction
# without it, Python 3.0 - 3.3 will not support pickling
# https://github.com/pytransitions/transitions/issues/236
_super = super


class TransitionGraphSupport(Transition):
    """ Transition used in conjunction with (Nested)Graphs to update graphs whenever a transition is
    conducted.
    """

    def __init__(self, *args, **kwargs):
        """ Accepts the same arguments as ``Transition`` plus an optional ``label`` keyword.
        A truthy label is stored as ``self.label`` (used instead of the trigger name in diagrams).
        """
        label = kwargs.pop('label', None)
        _super(TransitionGraphSupport, self).__init__(*args, **kwargs)
        if label:
            self.label = label

    def _change_state(self, event_data):
        """ Perform the state change and restyle the model's graph.

        Styling is reset, the taken transition is recorded via ``set_previous_transition`` and,
        after the actual state change, every current state of the model is styled 'active'.
        """
        graph = event_data.machine.model_graphs[event_data.model]
        graph.reset_styling()
        graph.set_previous_transition(self.source, self.dest, event_data.event.name)
        _super(TransitionGraphSupport, self)._change_state(event_data)  # pylint: disable=protected-access
        graph = event_data.machine.model_graphs[event_data.model]  # graph might have changed during change_event
        for state in _flatten(listify(getattr(event_data.model, event_data.machine.model_attribute))):
            # objects with a 'name' (e.g. Enum states) are represented by the transition's dest name
            graph.set_node_style(self.dest if hasattr(state, 'name') else state, 'active')


class GraphMachine(MarkupMachine):
    """ Extends transitions.core.Machine with graph support.
    Is also used as a mixin for HierarchicalMachine.
    Attributes:
        _pickle_blacklist (list): Objects that should not/do not need to be pickled.
        transition_cls (cls): TransitionGraphSupport
        machine_attributes (dict): Graph-level attributes passed to the graph engine. Hierarchical
            machines get a per-instance copy extended with ``hierarchical_machine_attributes``.
        style_attributes (dict): Styles per element kind ('node', 'edge', 'graph') and style name.
    """

    _pickle_blacklist = ['model_graphs']
    transition_cls = TransitionGraphSupport

    machine_attributes = {
        'directed': 'true',
        'strict': 'false',
        'rankdir': 'LR',
    }

    hierarchical_machine_attributes = {
        'rankdir': 'TB',
        'rank': 'source',
        'nodesep': '1.5',
        'compound': 'true'
    }

    style_attributes = {
        'node': {
            '': {},
            'default': {
                'style': 'rounded, filled',
                'shape': 'rectangle',
                'fillcolor': 'white',
                'color': 'black',
                'peripheries': '1'
            },
            'inactive': {
                'fillcolor': 'white',
                'color': 'black',
                'peripheries': '1'
            },
            'parallel': {
                'shape': 'rectangle',
                'color': 'black',
                'fillcolor': 'white',
                'style': 'dashed, rounded, filled',
                'peripheries': '1'
            },
            'active': {
                'color': 'red',
                'fillcolor': 'darksalmon',
                'peripheries': '2'
            },
            'previous': {
                'color': 'blue',
                'fillcolor': 'azure2',
                'peripheries': '1'
            }
        },
        'edge': {
            '': {},
            'default': {
                'color': 'black'
            },
            'previous': {
                'color': 'blue'
            }
        },
        'graph': {
            '': {},
            'default': {
                'color': 'black',
                'fillcolor': 'white',
                'style': 'solid'
            },
            'previous': {
                'color': 'blue',
                'fillcolor': 'azure2',
                'style': 'filled'
            },
            'active': {
                'color': 'red',
                'fillcolor': 'darksalmon',
                'style': 'filled'
            },
            'parallel': {
                'color': 'black',
                'fillcolor': 'white',
                'style': 'dotted'
            }
        }
    }

    # model_graphs cannot be pickled. Omit them.
    def __getstate__(self):
        """ Return the instance state without blacklisted (unpicklable) attributes. """
        return {k: v for k, v in self.__dict__.items() if k not in self._pickle_blacklist}

    def __setstate__(self, state):
        """ Restore the instance state and rebuild a graph for every model. """
        self.__dict__.update(state)
        self.model_graphs = {}  # reinitialize new model_graphs
        for model in self.models:
            try:
                _ = self._get_graph(model, title=self.title)
            except AttributeError as err:
                _LOGGER.warning("Graph for model could not be initialized after pickling: %s", err)

    def __init__(self, *args, **kwargs):
        """ Accepts the ``MarkupMachine`` arguments plus the graph options ``title``,
        ``show_conditions``, ``show_state_attributes``, ``show_auto_transitions`` and
        ``use_pygraphviz`` (falls back to graphviz if pygraphviz cannot be imported).
        """
        # remove graph config from keywords
        self.title = kwargs.pop('title', 'State Machine')
        self.show_conditions = kwargs.pop('show_conditions', False)
        self.show_state_attributes = kwargs.pop('show_state_attributes', False)
        # in MarkupMachine this switch is called 'with_auto_transitions'
        # keep 'auto_transitions_markup' for backwards compatibility.
        # Always pop 'show_auto_transitions' so it never leaks into the base constructor,
        # even when 'auto_transitions_markup' is already truthy.
        show_auto_transitions = kwargs.pop('show_auto_transitions', False)
        kwargs['auto_transitions_markup'] = kwargs.get('auto_transitions_markup', False) or show_auto_transitions
        self.model_graphs = {}

        # determine graph engine; if pygraphviz cannot be imported, fall back to graphviz
        use_pygraphviz = kwargs.pop('use_pygraphviz', True)
        if use_pygraphviz:
            try:
                import pygraphviz  # noqa: F401  pylint: disable=unused-import,import-outside-toplevel
            except ImportError:
                use_pygraphviz = False
        self.graph_cls = self._init_graphviz_engine(use_pygraphviz)

        _LOGGER.debug("Using graph engine %s", self.graph_cls)
        _super(GraphMachine, self).__init__(*args, **kwargs)

        # for backwards compatibility assign get_combined_graph to get_graph
        # if model is not the machine
        if not hasattr(self, 'get_graph'):
            setattr(self, 'get_graph', self.get_combined_graph)

    def _init_graphviz_engine(self, use_pygraphviz):
        """ Select the graph class to render diagrams with.

        Args:
            use_pygraphviz (bool): Try the pygraphviz backend first; graphviz is the fallback.
        Returns:
            The (Nested)Graph class of the selected backend.
        Raises:
            ImportError: If the graphviz fallback cannot be imported either.
        """
        # state class needs to have a separator and machine needs to be a context manager
        nested = hasattr(self.state_cls, 'separator') and hasattr(self, '__enter__')
        graph_cls = None
        if use_pygraphviz:
            try:
                if nested:
                    from .diagrams_pygraphviz import NestedGraph as graph_cls
                else:
                    from .diagrams_pygraphviz import Graph as graph_cls
            except ImportError:
                graph_cls = None  # fall back to graphviz below
        if graph_cls is None:
            if nested:
                from .diagrams_graphviz import NestedGraph as graph_cls
            else:
                from .diagrams_graphviz import Graph as graph_cls
        if nested:
            # Copy before updating: 'machine_attributes' is a class-level dict shared by all
            # instances; updating it in place would leak hierarchical settings into every machine.
            self.machine_attributes = dict(self.machine_attributes)
            self.machine_attributes.update(self.hierarchical_machine_attributes)
        return graph_cls

    def _get_graph(self, model, title=None, force_new=False, show_roi=False):
        """ Return the rendered graph of `model`; bound as a partial to models as ``get_graph``.

        Args:
            model: The model whose graph is requested.
            title (str): Title of the graph; the machine's title is used when a graph is created
                without an explicit title.
            force_new (bool): If True, (re-)generate the model's graph.
            show_roi (bool): If True, only render the region of interest around the current state.
        Returns:
            Whatever the graph object's ``get_graph`` returns.
        """
        if force_new or model not in self.model_graphs:
            graph = self.graph_cls(self, title=title if title is not None else self.title)
            self.model_graphs[model] = graph
            try:
                for state in _flatten(listify(getattr(model, self.model_attribute))):
                    # objects with a 'name' (e.g. Enum states) are referenced by that name
                    graph.set_node_style(state.name if hasattr(state, 'name') else state, 'active')
            except AttributeError:
                _LOGGER.info("Could not set active state of diagram")
        graph = self.model_graphs[model]
        graph.roi_state = getattr(model, self.model_attribute) if show_roi else None
        return graph.get_graph(title=title)

    def get_combined_graph(self, title=None, force_new=False, show_roi=False):
        """ This method is currently equivalent to 'get_graph' of the first machine's model.
        In future releases of transitions, this function will return a combined graph with active states
        of all models.
        Args:
            title (str): Title of the resulting graph.
            force_new (bool): If set to True, (re-)generate the model's graph.
            show_roi (bool): If set to True, only render states that are active and/or can be reached from
                the current state.
        Returns: AGraph of the first machine's model.
        """
        _LOGGER.info('Returning graph of the first model. In future releases, this '
                     'method will return a combined graph of all models.')
        return self._get_graph(self.models[0], title, force_new, show_roi)

    def add_model(self, model, initial=None):
        """ Add model(s), bind ``get_graph`` to each one and initialise its graph.

        Raises:
            AttributeError: If a model already has a ``get_graph`` attribute.
        """
        models = listify(model)
        _super(GraphMachine, self).add_model(models, initial)
        for mod in models:
            mod = self if mod == 'self' else mod
            if hasattr(mod, 'get_graph'):
                raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
            setattr(mod, 'get_graph', partial(self._get_graph, mod))
            _ = mod.get_graph(title=self.title, force_new=True)  # initialises graph

    def add_states(self, states, on_enter=None, on_exit=None,
                   ignore_invalid_triggers=None, **kwargs):
        """ Calls the base method and regenerates all models's graphs. """
        _super(GraphMachine, self).add_states(states, on_enter=on_enter, on_exit=on_exit,
                                              ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
        for model in self.models:
            model.get_graph(force_new=True)

    def add_transition(self, trigger, source, dest, conditions=None,
                       unless=None, before=None, after=None, prepare=None, **kwargs):
        """ Calls the base method and regenerates all models's graphs. """
        _super(GraphMachine, self).add_transition(trigger, source, dest, conditions=conditions, unless=unless,
                                                  before=before, after=after, prepare=prepare, **kwargs)
        for model in self.models:
            model.get_graph(force_new=True)


class BaseGraph(object):
    """ Shared helpers for graph backends. ``generate`` is called on construction and is expected
    to be provided by subclasses (it is not defined here). """

    def __init__(self, machine, title=None):
        self.machine = machine
        self.fsm_graph = None
        self.roi_state = None
        self.generate(title)

    def _convert_state_attributes(self, state):
        """ Build a node label from a markup state dict.

        The label defaults to the state's name. If the machine has ``show_state_attributes`` set,
        tags, enter/exit callbacks and timeout information are appended (``\\l`` is the DOT escape
        for a left-justified line break).
        """
        label = state.get('label', state['name'])
        if self.machine.show_state_attributes:
            if 'tags' in state:
                label += ' [' + ', '.join(state['tags']) + ']'
            if 'on_enter' in state:
                label += r'\l- enter:\l  + ' + r'\l  + '.join(state['on_enter'])
            if 'on_exit' in state:
                label += r'\l- exit:\l  + ' + r'\l  + '.join(state['on_exit'])
            if 'timeout' in state:
                # timeout may be numeric; convert explicitly before concatenating
                label += r'\l- timeout(' + str(state['timeout']) + 's)  -> (' + \
                         ', '.join(state.get('on_timeout', [])) + ')'
        return label

    def _transition_label(self, tran):
        """ Build an edge label from a markup transition dict.

        Uses the transition's label (default: trigger), marks transitions without 'dest' as
        internal and, if the machine has ``show_conditions`` set, appends conditions and negated
        'unless' callbacks joined by ' & '.
        """
        edge_label = tran.get('label', tran['trigger'])
        if 'dest' not in tran:
            edge_label += " [internal]"
        if self.machine.show_conditions and any(prop in tran for prop in ['conditions', 'unless']):
            guards = tran.get('conditions', []) + ['!' + u for u in tran.get('unless', [])]
            return '{edge_label} [{conditions}]'.format(
                edge_label=edge_label,
                conditions=' & '.join(guards),
            )
        return edge_label

    def _get_global_name(self, path):
        """ Return the machine's global name after entering every state of `path` in turn via
        the machine's context manager. Note: consumes `path` (elements are popped). """
        if path:
            state = path.pop(0)
            with self.machine(state):
                return self._get_global_name(path)
        else:
            return self.machine.get_global_name()

    def _get_elements(self):
        """ Collect top-level states and all transitions from the machine's markup.

        Nested scopes are walked breadth-first; transitions of nested scopes get their source/dest
        prefixed with the scope path, and states with a (non-list) initial state get an extra
        unlabeled transition from '<state>_anchor' to that initial child.

        Returns:
            tuple: (states, transitions). On a missing markup key, an error is logged and the
            partial result collected so far is returned.
        """
        states = []
        transitions = []
        try:
            markup = self.machine.get_markup_config()
            separator = getattr(self.machine.state_cls, 'separator', None)
            queue = deque([([], markup)])

            while queue:
                prefix, scope = queue.popleft()
                for transition in scope.get('transitions', []):
                    if prefix:
                        tran = copy.copy(transition)
                        tran['source'] = separator.join(prefix + [tran['source']])
                        if 'dest' in tran:  # don't do this for internal transitions
                            tran['dest'] = separator.join(prefix + [tran['dest']])
                    else:
                        tran = transition
                    transitions.append(tran)
                for state in scope.get('children', []) + scope.get('states', []):
                    if not prefix:
                        states.append(state)

                    ini = state.get('initial', [])
                    if not isinstance(ini, list):
                        ini = ini.name if hasattr(ini, 'name') else ini
                        transitions.append(dict(
                            trigger='',
                            source=self.machine.state_cls.separator.join(prefix + [state['name']]) + '_anchor',
                            dest=self.machine.state_cls.separator.join(prefix + [state['name'], ini])))
                    if state.get('children', []):
                        queue.append((prefix + [state['name']], state))
        except KeyError as err:
            _LOGGER.error("Graph creation incomplete! Missing key: %s", err)
        return states, transitions


def _flatten(item):
    """ Recursively yield the leaves of nested lists, tuples and sets. """
    for elem in item:
        if isinstance(elem, (list, tuple, set)):
            for res in _flatten(elem):
                yield res
        else:
            yield elem