1#
2#     This file is part of CasADi.
3#
4#     CasADi -- A symbolic framework for dynamic optimization.
5#     Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6#                             K.U. Leuven. All rights reserved.
7#     Copyright (C) 2011-2014 Greg Horn
8#
9#     CasADi is free software; you can redistribute it and/or
10#     modify it under the terms of the GNU Lesser General Public
11#     License as published by the Free Software Foundation; either
12#     version 3 of the License, or (at your option) any later version.
13#
14#     CasADi is distributed in the hope that it will be useful,
15#     but WITHOUT ANY WARRANTY; without even the implied warranty of
16#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17#     Lesser General Public License for more details.
18#
19#     You should have received a copy of the GNU Lesser General Public
20#     License along with CasADi; if not, write to the Free Software
21#     Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22#
23#
24from casadi import SX, MX, print_operator
25import casadi as C
26
27try:
28  from pydot import pydot
29except:
30  try:
31    import pydot
32  except:
33    raise Exception("To use the functionality of casadi.tools.graph, you need to have pydot Installed. Try `easy_install pydot`.")
34
35#import ipdb
36
def hashcompare(self,other):
  """Three-way compare two objects by their hash values.

  Returns -1, 0 or 1 when ``hash(self)`` is less than, equal to, or greater
  than ``hash(other)``.  Installed as ``__cmp__`` on expression classes by
  ``getDeps``.  Python 3 has no ``cmp`` builtin, so the comparison is
  written out explicitly.
  """
  h_self, h_other = hash(self), hash(other)
  return (h_self > h_other) - (h_self < h_other)
39
40
def equality(self,other):
  """Return True when ``self`` and ``other`` have the same hash value.

  Installed as ``__eq__`` on expression classes by ``getDeps``, so the
  dependency dicts built in this module match expressions by hash.
  """
  same_hash = hash(other) == hash(self)
  return same_hash
43
def getDeps(s):
  """Return the direct operands of expression ``s`` as a new list.

  Objects without an ``n_dep`` method have no operands and yield ``[]``.
  Side effect: the class of every returned operand gets its ``__cmp__``
  and ``__eq__`` replaced by the hash-based ``hashcompare``/``equality``.
  """
  if not hasattr(s, 'n_dep'):
    return []
  children = []
  for index in range(s.n_dep()):
    child = s.dep(index)
    child_class = child.__class__
    child_class.__cmp__ = hashcompare
    child_class.__eq__ = equality
    children.append(child)
  return children
53
def addDependency(master,slave,dep=None,invdep=None):
  """Record that expression ``master`` uses expression ``slave``.

  Parameters
    master   the using node.
    slave    the used node.
    dep      dict node -> set of nodes it uses; updated in place.
    invdep   dict node -> set of nodes that use it; updated in place.

  When ``dep``/``invdep`` are omitted, fresh dicts are used rather than a
  mutable default shared between calls.
  """
  if dep is None:
    dep = {}
  if invdep is None:
    invdep = {}
  dep.setdefault(master, set()).add(slave)
  invdep.setdefault(slave, set()).add(master)
64
def addDependencies(master,slaves,dep=None,invdep=None):
  """Record ``master -> slave`` for every slave, then expand each slave.

  ``dep``/``invdep`` are the forward/inverse adjacency dicts (node -> set
  of nodes), updated in place.  Fresh dicts are used when they are omitted.

  A slave that already has an entry in ``dep`` has already been expanded:
  only ``addDependencies`` for that node puts it there, and an expression
  graph has no cycles.  Because the adjacency values are sets, expanding it
  again would add nothing.  Skipping it avoids exponential re-traversal of
  shared subexpressions.
  """
  if dep is None:
    dep = {}
  if invdep is None:
    invdep = {}
  for slave in slaves:
    addDependency(master,slave,dep = dep,invdep = invdep)
  for slave in slaves:
    if slave in dep:
      continue
    dependencyGraph(slave,dep = dep,invdep = invdep)
70
def is_leaf(s):
  """Return ``s.is_leaf()``.

  Kept as a helper so the leaf criterion lives in one place.  An earlier
  criterion was ``s.is_scalar(True) and (s.is_constant() or s.is_symbolic())``.
  """
  leaf_check = s.is_leaf
  return leaf_check()
74
def dependencyGraph(s,dep = None,invdep = None):
  """Collect the dependency graph of an SX or MX expression.

  Parameters
    s       the expression; objects that are neither SX nor MX add nothing.
    dep     dict node -> set of nodes it uses; updated in place.
    invdep  dict node -> set of nodes that use it; updated in place.

  Returns the ``(dep, invdep)`` pair.  When the dicts are omitted, fresh
  ones are created, so repeated calls do not accumulate nodes from earlier
  expressions (the old ``{}`` defaults were shared and returned).

  The rules are:
    - a non-scalar SX depends on its nonzeros;
    - a scalar non-leaf SX and any MX depend on their operands (``getDeps``);
    - a scalar SX leaf has no dependencies.
  """
  if dep is None:
    dep = {}
  if invdep is None:
    invdep = {}
  if isinstance(s,SX):
    if not s.is_scalar(True):
      addDependencies(s,s.nonzeros(),dep = dep,invdep = invdep)
    elif not is_leaf(s):
      addDependencies(s,getDeps(s),dep = dep,invdep = invdep)
  elif isinstance(s,MX):
    addDependencies(s,getDeps(s),dep = dep,invdep = invdep)
  return (dep,invdep)
85
class DotArtist:
  """Base class for objects that draw one expression node into a pydot graph.

  Attributes (set by ``__init__``)
    s        the expression node to draw.
    dep      dict node -> set of nodes it uses.
    invdep   dict node -> set of nodes that use it.
    graph    the pydot graph to draw into.
    artists  dict node -> artist; shared by all artists of one graph.
    kwargs   extra drawing options, e.g. ``max_numel`` / ``max_nnz``.
  """
  sparsitycol = "#eeeeee"

  def __init__(self,s,dep=None,invdep=None,graph=None,artists=None,**kwargs):
    self.s = s
    # Omitted containers get a fresh dict per instance; the former ``{}``
    # defaults were shared by every artist created without them.  Test with
    # ``is None`` rather than truthiness: dotgraph passes an initially empty
    # ``artists`` dict that it fills afterwards, so its identity must be kept.
    self.dep = {} if dep is None else dep
    self.invdep = {} if invdep is None else invdep
    self.graph = graph
    self.artists = {} if artists is None else artists
    self.kwargs = kwargs

  def hasPorts(self):
    """Return False; subclasses drawing per-nonzero ``f<k>`` ports may return True."""
    return False

  def drawSparsity(self,s,id=None,depid=None,graph=None,nzlabels=None):
    """Draw the sparsity pattern of ``s`` as a node and add an edge ``depid -> id``.

    Parameters
      s         expression whose sparsity is shown.
      id        name of the new node; defaults to ``str(s.__hash__())``.
      depid     name of the edge's source node; defaults to the hash of
                ``s.dep(0)``.
      graph     graph to draw into; defaults to ``self.graph``.
      nzlabels  one label per nonzero, consumed in the order cells are
                emitted (row by row); defaults to ``"0"``, ``"1"``, ...

    A dense ``s`` is drawn as a filled ``rows x cols`` rectangle.  Otherwise
    it is an HTML table with ``.`` for structural zeros and a port ``f<k>``
    on the cell holding nonzero ``k``.
    """
    if id is None:
      id = str(s.__hash__())
    if depid is None:
      depid = str(s.dep(0).__hash__())
    if graph is None:
      graph = self.graph
    sp = s.sparsity()
    if nzlabels is None:
      nzlabels = [str(i) for i in range(sp.nnz())]
    if s.nnz()==s.numel():
      graph.add_node(pydot.Node(id,label="%d x %d" % (s.size1(),s.size2()),shape='rectangle',color=self.sparsitycol,style="filled"))
    else:
      cells = ['<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">',
               "<TR><TD COLSPAN='%d'><font color='#666666'>%s</font></TD></TR>" % (s.size2(), s.dim())]
      nzlabelcounter = 0
      for i in range(s.size1()):
        cells.append("<TR>")
        for j in range(s.size2()):
          k = sp.get_nz(i,j)
          if k==-1:
            cells.append("<TD>.</TD>")
          else:
            cells.append("<TD PORT='f%d' BGCOLOR='%s'>%s</TD>" % (k,self.sparsitycol,nzlabels[nzlabelcounter]))
            nzlabelcounter += 1
        cells.append("</TR>")
      cells.append("</TABLE>>")
      graph.add_node(pydot.Node(id,label="".join(cells),shape='plaintext'))
    graph.add_edge(pydot.Edge(depid,id))
129
class MXSymbolicArtist(DotArtist):
  """Draws a symbolic MX primitive.

  A dense 1x1 symbol becomes a red-bordered rectangle labelled with its
  name.  Any other shape becomes an HTML table headed ``name: dims`` with
  one ``f<k>`` port per nonzero; each such cell shows ``(row,col | k)``.
  """
  def hasPorts(self):
    return True

  def draw(self):
    expr = self.s
    node_id = str(expr.__hash__())
    color = "#990000"
    if expr.nnz() == expr.numel() and expr.nnz() == 1:
      # Single dense element: a plain rectangle, addressed via port f0.
      self.graph.add_node(pydot.Node(node_id+":f0",label=expr.name(),shape='rectangle',color=color))
      return

    pattern = expr.sparsity()
    cells = ['<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % color,
             "<TR><TD COLSPAN='%d'>%s: <font color='#666666'>%s</font></TD></TR>" % (expr.size2(),expr.name(), expr.dim())]
    for r in range(expr.size1()):
      cells.append("<TR>")
      for c in range(expr.size2()):
        nz = pattern.get_nz(r,c)
        if nz == -1:
          cells.append("<TD>.</TD>")
        else:
          cells.append("<TD PORT='f%d' BGCOLOR='#eeeeee'> <font color='#666666'>(%d,%d | %d)</font> </TD>" % (nz,r,c,nz))
      cells.append("</TR>")
    cells.append("</TABLE>>")
    self.graph.add_node(pydot.Node(node_id,label="".join(cells),shape='plaintext'))
158
159# class MXMappingArtist(DotArtist):
160#   def draw(self):
161#     s = self.s
162#     graph = self.graph
163#     sp = s.sparsity()
164#     row = sp.row()
165
166
167#     # Note: due to Mapping restructuring, this is no longer efficient code
168#     deps = getDeps(s)
169
170#     depind = s.depInd()
171#     nzmap = sum([s.mapping(i) for i in range(len(deps))])
172
173#     for k,d in enumerate(deps):
174#       candidates = map(hash,filter(lambda i: i.isMapping(),self.invdep[d]))
175#       candidates.sort()
176#       if candidates[0] == hash(s):
177#         graph.add_edge(pydot.Edge(str(d.__hash__()),"mapinput" + str(d.__hash__())))
178
179#     graph = pydot.Cluster('clustertest' + str(s.__hash__()), rank='max', label='Mapping')
180#     self.graph.add_subgraph(graph)
181
182
183
184#     colors = ['#eeeecc','#ccccee','#cceeee','#eeeecc','#eeccee','#cceecc']
185
186#     for k,d in enumerate(deps):
187#       spd = d.sparsity()
188#       #ipdb.set_trace()
189#       # The Matrix grid is represented by a html table with 'ports'
190#       candidates = map(hash,filter(lambda i: i.isMapping(),self.invdep[d]))
191#       candidates.sort()
192#       if candidates[0] == hash(s):
193#         label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="#0000aa">'
194#         if not(d.numel()==1 and d.numel()==d.nnz()):
195#           label+="<TR><TD COLSPAN='%d' BGCOLOR='#dddddd'><font>%s</font></TD></TR>" % (d.size2(), d.dim())
196#         for i in range(d.size1()):
197#           label+="<TR>"
198#           for j in range(d.size2()):
199#             kk = spd.get_nz(i,j)
200#             if kk==-1:
201#               label+="<TD>.</TD>"
202#             else:
203#               label+="<TD PORT='f%d' BGCOLOR='%s'> <font color='#666666'>%d</font> </TD>" % (kk,colors[k],kk)
204#           label+="</TR>"
205#         label+="</TABLE>>"
206#         graph.add_node(pydot.Node("mapinput" + str(d.__hash__()),label=label,shape='plaintext'))
207#       graph.add_edge(pydot.Edge("mapinput" + str(d.__hash__()),str(s.__hash__())))
208
209#     # The Matrix grid is represented by a html table with 'ports'
210#     label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">'
211#     if not(s.numel()==1 and s.numel()==s.nnz()):
212#       label+="<TR><TD COLSPAN='%d'><font color='#666666'>%s</font></TD></TR>" % (s.size2(), s.dim())
213#     for i in range(s.size1()):
214#       label+="<TR>"
215#       for j in range(s.size2()):
216#         k = sp.get_nz(i,j)
217#         if k==-1:
218#           label+="<TD>.</TD>"
219#         else:
220#           label+="<TD PORT='f%d' BGCOLOR='%s'> <font color='#666666'>%d</font> </TD>" % (k,colors[depind[k]],nzmap[k])
221#       label+="</TR>"
222#     label+="</TABLE>>"
223#     graph.add_node(pydot.Node(str(self.s.__hash__()),label=label,shape='plaintext'))
224
225
class MXEvaluationArtist(DotArtist):
  """Draws a function-call MX node as a cluster with two ``Mrecord`` nodes.

  The ``funinput<hash>`` record has one ``f<k>`` port per function input,
  and each operand of the call gets an edge into its port.  The outputs
  record is named by the call node's hash and has one ``f<i>`` port per
  function output.
  """
  def draw(self):
    call = self.s
    call_id = str(call.__hash__())
    operands = getDeps(call)
    fun = call.which_function()

    def record(count, kind):
      # Label such as " 2 inputs: |<f0> 0 | <f1> 1".
      return (" %d %s: |" % (count, kind)) + " | ".join("<f%d> %d" % (i,i) for i in range(count))

    input_id = "funinput" + call_id
    for port, operand in enumerate(operands):
      # NOTE(review): rankdir is a graph attribute; graphviz presumably
      # ignores it on edges.
      self.graph.add_edge(pydot.Edge(str(operand.__hash__()),input_id + ":f%d" % port,rankdir="LR"))

    cluster = pydot.Cluster(call_id, rank='max', label='Function:\n %s' % fun.name())
    self.graph.add_subgraph(cluster)
    cluster.add_node(pydot.Node(input_id,label=record(fun.n_in(), "inputs"),shape='Mrecord'))
    cluster.add_node(pydot.Node(call_id,label=record(fun.n_out(), "outputs"),shape='Mrecord'))
249
250
class MXConstantArtist(DotArtist):
  """Draws a constant MX node.

  A dense 1x1 constant becomes a green-bordered rectangle showing its
  value.  Otherwise an HTML table shows the dimensions and, unless the
  ``max_numel`` option is given and ``numel() >= max_numel``, one cell per
  entry: ``.`` for structural zeros, the value (with port ``f<k>``) for
  nonzeros.
  """
  def hasPorts(self):
    return True

  def draw(self):
    const = self.s
    node_id = str(const.__hash__())
    pattern = const.sparsity()
    values = const.to_DM()
    frame = "#009900"
    if const.nnz() == const.numel() and const.nnz() == 1:
      self.graph.add_node(pydot.Node(node_id+":f0",label=str(values[0,0]),shape='rectangle',color=frame))
      return

    cells = ['<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % frame,
             "<TR><TD COLSPAN='%d'><font color='#666666'>%s</font></TD></TR>" % (const.size2(), const.dim())]
    if "max_numel" not in self.kwargs or const.numel() < self.kwargs["max_numel"]:
      for r in range(const.size1()):
        cells.append("<TR>")
        for c in range(const.size2()):
          nz = pattern.get_nz(r,c)
          if nz == -1:
            cells.append("<TD>.</TD>")
          else:
            cells.append("<TD PORT='f%d' BGCOLOR='#eeeeee'> %s </TD>" % (nz,values[r,c]))
        cells.append("</TR>")
    cells.append("</TABLE>>")
    self.graph.add_node(pydot.Node(node_id,label="".join(cells),shape='plaintext'))
279
class MXGenericArtist(DotArtist):
  """Fallback artist for MX nodes without a dedicated artist.

  When some operand's sparsity differs from the node's, the node's
  sparsity is drawn (``drawSparsity``) and the operator node gets an
  ``op`` prefix.  Multi-operand nodes become ``Mrecord`` shapes with one
  ``f<i>`` port per operand; others become an oval.
  """
  def draw(self):
    node = self.s
    operands = getDeps(node)
    own_sp = node.sparsity()
    differing = [d for d in operands if not (d.sparsity() == own_sp)]
    target = ("op" if differing else "") + str(node.__hash__())
    if differing:
      self.drawSparsity(node,depid=target)

    if len(operands) > 1:
      # Record shape: each operand's edge lands on its own port.
      label = print_operator(node,["| <f%d> | " % i for i in range(len(operands))])
      if label.startswith("(|") and label.endswith("|)"):
        label = label[2:-2]
      self.graph.add_node(pydot.Node(target,label=label,shape='Mrecord'))
      for port, operand in enumerate(operands):
        self.graph.add_edge(pydot.Edge(str(operand.__hash__()),target+":f%d" % port))
    else:
      self.graph.add_node(pydot.Node(target,label=print_operator(node,["."]),shape='oval'))
      for operand in operands:
        self.graph.add_edge(pydot.Edge(str(operand.__hash__()),target))
309
class MXGetNonzerosArtist(DotArtist):
  """Draws an ``OP_GETNONZEROS`` MX node and links its operand to it.

  A dense 1x1 result is a filled rectangle labelled ``[mapping()[0,0]]``.
  Anything else first gets its sparsity drawn (``drawSparsity``) and is
  then an HTML table headed ``getNonzeros``.  Unless the ``max_nnz`` option
  is given and ``nnz() >= max_nnz``, the table has one ``f<k>`` port per
  result nonzero showing the ``mapping()`` entry at that position.
  """
  def draw(self):
    node = self.s
    source = getDeps(node)[0]
    single = node.nnz() == node.numel() and node.nnz() == 1
    target = ("" if single else "op") + str(node.__hash__())
    if not single:
      self.drawSparsity(node,depid=target)

    pattern = node.sparsity()
    mapping = node.mapping()
    frame = "#333333"
    if single:
      self.graph.add_node(pydot.Node(target+":f0",label="[%s]" % str(mapping[0,0]),shape='rectangle',style="filled",fillcolor='#eeeeff'))
    else:
      cells = ['<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % frame,
               "<TR><TD COLSPAN='%d' PORT='entry'>getNonzeros</TD></TR>" % (node.size2())]
      if "max_nnz" not in self.kwargs or node.nnz() < self.kwargs["max_nnz"]:
        for r in range(node.size1()):
          cells.append("<TR>")
          for c in range(node.size2()):
            nz = pattern.get_nz(r,c)
            if nz == -1:
              cells.append("<TD>.</TD>")
            else:
              cells.append("<TD PORT='f%d' BGCOLOR='#eeeeff'> %s </TD>" % (nz,mapping[r,c]))
          cells.append("</TR>")
      cells.append("</TABLE>>")
      self.graph.add_node(pydot.Node(target,label="".join(cells),shape='plaintext'))
    self.graph.add_edge(pydot.Edge(str(source.__hash__()),target))
348
class MXSetNonzerosArtist(DotArtist):
  """Draws an ``OP_SETNONZEROS`` MX node as an HTML table headed ``setNonzeros``.

  Operand 0 (``entry``) gets an edge into the header port ``entry``; operand
  1 (``target``) gets an edge into the table node itself.  When any
  operand's sparsity differs from the node's, the node's own sparsity is
  also drawn via ``drawSparsity`` and the table node gets an ``op`` prefix.
  """
  def draw(self):
    s = self.s
    graph = self.graph
    # NOTE(review): which casadi operand is the assigned value and which is
    # the matrix written into is not visible here -- confirm what ``entry``
    # (dep 0) and ``target`` (dep 1) stand for.
    entry = getDeps(s)[0]
    target = getDeps(s)[1]


    # A separate sparsity node is drawn unless all operands share the
    # node's sparsity pattern.
    show_sp = not(all([d.sparsity()==s.sparsity() for d in getDeps(s)]))

    if show_sp:
      op = "op"
      self.drawSparsity(s,depid=op + str(s.__hash__()))
    else:
      op = ""

    # NOTE(review): ``sp`` and ``row`` are unused below.  The grid loop reads
    # ``entry.sparsity()``, whereas MXAddNonzerosArtist reads
    # ``target.sparsity()``; confirm which operand is intended.
    sp = target.sparsity()
    row = sp.row()
    M = list(s.mapping())
    # Cursor into M; advanced as matching cells are emitted.
    Mk = 0
    col = "#333333"

    # The Matrix grid is represented by a html table with 'ports'
    label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % col
    label+="<TR><TD COLSPAN='%d' PORT='entry'>setNonzeros</TD></TR>" % (s.size2())
    for i in range(s.size1()):
      label+="<TR>"
      for j in range(s.size2()):
        k = entry.sparsity().get_nz(i,j)
        # A cell gets a port only when its nonzero index k equals M[Mk];
        # structural zeros, mismatches and an exhausted M all render '.'.
        if k==-1 or Mk>= len(M) or k != M[Mk]:
          label+="<TD>.</TD>"
          # A -1 entry of M is skipped when reached at a structural nonzero,
          # but only while Mk is before the last entry of M.
          if Mk< len(M)-1 and M[Mk]==-1 and k!=-1: Mk+=1
        else:
          # Port name and displayed text are both the cursor position Mk.
          label+="<TD PORT='f%d' BGCOLOR='#eeeeff'> %s </TD>" % (Mk,Mk)
          Mk+=1
      label+="</TR>"
    label+="</TABLE>>"
    graph.add_node(pydot.Node(op+str(s.__hash__()),label=label,shape='plaintext'))
    self.graph.add_edge(pydot.Edge(str(entry.__hash__()),op+str(s.__hash__())+':entry'))
    self.graph.add_edge(pydot.Edge(str(target.__hash__()),op+str(s.__hash__())))
389
390
class MXAddNonzerosArtist(DotArtist):
  """Draws an ``OP_ADDNONZEROS`` MX node as an HTML table headed ``addNonzeros``.

  Operand 0 (``entry``) gets an edge into the header port ``entry``; operand
  1 (``target``) gets an edge into the table node itself.  When any
  operand's sparsity differs from the node's, the node's own sparsity is
  also drawn via ``drawSparsity`` and the table node gets an ``op`` prefix.
  """
  def draw(self):
    s = self.s
    graph = self.graph
    # NOTE(review): the roles of dep 0 (``entry``) and dep 1 (``target``)
    # in casadi's AddNonzeros are not visible here -- confirm.
    entry = getDeps(s)[0]
    target = getDeps(s)[1]

    # A separate sparsity node is drawn unless all operands share the
    # node's sparsity pattern.
    show_sp = not(all([d.sparsity()==s.sparsity() for d in getDeps(s)]))

    if show_sp:
      op = "op"
      self.drawSparsity(s,depid=op + str(s.__hash__()))
    else:
      op = ""

    # The grid cells are looked up in ``target``'s sparsity, while the loop
    # bounds come from ``s``.  NOTE(review): MXSetNonzerosArtist uses
    # ``entry.sparsity()`` here instead -- confirm which is intended.
    sp = target.sparsity()
    row = sp.row()
    M = list(s.mapping())
    # Cursor into M; advanced as matching cells are emitted.
    Mk = 0
    col = "#333333"

    # The Matrix grid is represented by a html table with 'ports'
    label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % col
    label+="<TR><TD COLSPAN='%d' PORT='entry'>addNonzeros</TD></TR>" % (s.size2())
    for i in range(s.size1()):
      label+="<TR>"
      for j in range(s.size2()):
        k = sp.get_nz(i,j)
        # A cell gets a port only when its nonzero index k equals M[Mk];
        # structural zeros, mismatches and an exhausted M all render '.'.
        if k==-1 or Mk>= len(M) or k != M[Mk]:
          label+="<TD>.</TD>"
          # A -1 entry of M is skipped when reached at a structural nonzero,
          # but only while Mk is before the last entry of M.
          if Mk< len(M)-1 and M[Mk]==-1 and k!=-1: Mk+=1
        else:
          # Port name and displayed text are both the cursor position Mk.
          label+="<TD PORT='f%d' BGCOLOR='#eeeeff'> %s </TD>" % (Mk,Mk)
          Mk+=1
      label+="</TR>"
    label+="</TABLE>>"
    graph.add_node(pydot.Node(op+str(s.__hash__()),label=label,shape='plaintext'))
    self.graph.add_edge(pydot.Edge(str(entry.__hash__()),op+str(s.__hash__())+':entry'))
    self.graph.add_edge(pydot.Edge(str(target.__hash__()),op+str(s.__hash__())))
430
431
class MXOperationArtist(DotArtist):
  """Draws a unary or binary MX operation node.

  The node's sparsity is drawn separately (``drawSparsity``, with the
  operator node prefixed ``op``) unless every operand already has the
  node's sparsity.  Commutative operators become compact ovals.  Other
  operators become ``Mrecord`` shapes with ports ``f0``/``f1``, so each
  operand's edge lands in its slot.
  """
  def draw(self):
    node = self.s
    operands = getDeps(node)
    own_sp = node.sparsity()
    unary_same = node.is_unary() and operands[0].sparsity()==own_sp
    binary_same = node.is_binary() and operands[0].sparsity()==own_sp and operands[1].sparsity()==own_sp
    same_sparsity = unary_same or binary_same
    target = ("" if same_sparsity else "op") + str(node.__hash__())
    if not same_sparsity:
      self.drawSparsity(node,depid=target)

    if node.is_commutative():
      # Oval labelled with the bare operator text.
      label = print_operator(node,[".", "."])
      if label.startswith("(.") and label.endswith(".)"):
        label = label[2:-2]
      if label.startswith("(") and label.endswith(")"):
        label = label[1:-1]
      self.graph.add_node(pydot.Node(target,label=label,shape='oval'))
      for operand in operands:
        self.graph.add_edge(pydot.Edge(str(operand.__hash__()),target))
    else:
      # Record whose ports mark the operand positions.
      label = print_operator(node,["| <f0> | ", " | <f1> |"])
      if label.startswith("(|") and label.endswith("|)"):
        label = label[2:-2]
      self.graph.add_node(pydot.Node(target,label=label,shape='Mrecord'))
      for port, operand in enumerate(operands):
        self.graph.add_edge(pydot.Edge(str(operand.__hash__()),target+":f%d" % port))
471
class MXIfTestArtist(DotArtist):
  """Draws a node as the record ``<f0> ? | <f1> true``, wiring operands to ports in order.

  Not selected by ``createArtist``.
  """
  def draw(self):
    node_id = str(self.s.__hash__())
    operands = getDeps(self.s)
    self.graph.add_node(pydot.Node(node_id,label="<f0> ? | <f1> true",shape='Mrecord'))
    for port, operand in enumerate(operands):
      self.graph.add_edge(pydot.Edge(str(operand.__hash__()),"%s:f%d" % (node_id, port)))
485
class MXDensificationArtist(DotArtist):
  """Draws a node as a ``densify(.)`` oval fed by its first operand.

  Not selected by ``createArtist``.
  """
  def draw(self):
    node_id = str(self.s.__hash__())
    operands = getDeps(self.s)
    self.graph.add_node(pydot.Node(node_id,label="densify(.)",shape='oval'))
    self.graph.add_edge(pydot.Edge(str(operands[0].__hash__()),node_id))
494
class MXNormArtist(DotArtist):
  """Draws a norm MX node as an oval labelled by ``print_operator``, fed by its first operand."""
  def draw(self):
    node = self.s
    node_id = str(node.__hash__())
    operands = getDeps(node)
    text = print_operator(node,[".", "."])
    self.graph.add_node(pydot.Node(node_id,label=text,shape='oval'))
    self.graph.add_edge(pydot.Edge(str(operands[0].__hash__()),node_id))
503
class MXEvaluationOutputArtist(DotArtist):
  """Draws one output of a function call as a sparsity node.

  The node is fed from port ``f<which_output()>`` of the node named after
  the hash of the call (operand 0).
  """
  def draw(self):
    out = self.s
    call_id = str(hash(out.dep(0)))
    self.drawSparsity(out,depid="%s:f%d" % (call_id, out.which_output()))
509
510
class MXMultiplicationArtist(DotArtist):
  """Draws a multiplication node as the record ``mul(| <f0> | , | <f1> | )``.

  Operands are wired to ports ``f0``/``f1`` in order.  Not selected by
  ``createArtist``.
  """
  def draw(self):
    node_id = str(self.s.__hash__())
    operands = getDeps(self.s)
    self.graph.add_node(pydot.Node(node_id,label="mul(| <f0> | , | <f1> | )",shape='Mrecord'))
    for port, operand in enumerate(operands):
      self.graph.add_edge(pydot.Edge(str(operand.__hash__()),"%s:f%d" % (node_id, port)))
524
class SXArtist(DotArtist):
  """Draws a non-scalar SX matrix as an HTML table.

  Structural zeros are shown as ``.``.  A nonzero accepted by
  ``shouldEmbed`` is printed inline.  Every other nonzero gets a port
  ``f<k>`` labelled ``(row,col|k)`` and an edge from that element's node.
  """
  def draw(self):
    matrix = self.s
    matrix_id = str(matrix.__hash__())
    pattern = matrix.sparsity()
    cells = ['<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">']
    for r in range(matrix.size1()):
      cells.append("<TR>")
      for c in range(matrix.size2()):
        nz = pattern.get_nz(r,c)
        if nz == -1:
          cells.append("<TD>.</TD>")
          continue
        element = matrix.nz[nz]
        if self.shouldEmbed(element):
          cells.append("<TD BGCOLOR='#eeeeee'>%s</TD>" % str(element))
        else:
          self.graph.add_edge(pydot.Edge(str(element.__hash__()),"%s:f%d" % (matrix_id, nz)))
          cells.append("<TD PORT='f%d' BGCOLOR='#eeeeee'> <font color='#666666'>(%d,%d|%d)</font> </TD>" % (nz,r,c,nz))
      cells.append("</TR>")
    cells.append("</TABLE>>")
    self.graph.add_node(pydot.Node(matrix_id,label="".join(cells),shape='plaintext'))

  def shouldEmbed(self,sx):
    """Return whether element ``sx`` is printed inside the table instead of as its own node.

    True only for a leaf that has exactly one user in ``invdep``.
    """
    if len(self.invdep[sx]) != 1:
      return False
    return sx.is_leaf()
553
class SXLeafArtist(DotArtist):
  """Draws a scalar SX leaf as a box.

  Nothing is drawn when the leaf's only user has an artist whose
  ``shouldEmbed`` accepts it, because that artist prints the leaf inline.
  Constants get a bold border; other leaves get a solid one.
  """
  def draw(self):
    leaf = self.s
    users = self.invdep[leaf]
    if len(users) == 1:
      (user,) = users
      user_artist = self.artists[user]
      if hasattr(user_artist,'shouldEmbed') and user_artist.shouldEmbed(leaf):
        return
    border = "bold" if leaf.is_constant() else "solid"
    self.graph.add_node(pydot.Node(str(leaf.__hash__()),label=str(leaf),shape="box",style=border))
565
class SXNonLeafArtist(DotArtist):
  """Draws a scalar SX operation node.

  Commutative operators become an oval labelled with the operator text
  (surrounding ``(.``/``.)`` or ``(``/``)`` trimmed), with every operand
  linked to the node.  Non-commutative operators become an ``Mrecord`` with
  ports ``f0`` (and ``f1`` for two operands), so each edge lands in its
  operand slot.
  """
  def draw(self):
    node = self.s
    node_id = str(node.__hash__())
    operands = getDeps(node)
    if node.is_commutative():
      label = print_operator(node,[".", "."])
      if label.startswith("(.") and label.endswith(".)"):
        label = label[2:-2]
      if label.startswith("(") and label.endswith(")"):
        label = label[1:-1]
      self.graph.add_node(pydot.Node(node_id,label=label,shape='oval'))
      for operand in operands:
        self.graph.add_edge(pydot.Edge(str(operand.__hash__()),node_id))
      return

    slots = ["| <f0> | ", " | <f1> |"] if len(operands) == 2 else ["| <f0> | "]
    label = print_operator(node,slots)
    if label.startswith("(|") and label.endswith("|)"):
      label = label[2:-2]
    self.graph.add_node(pydot.Node(node_id,label=label,shape='Mrecord'))
    for port, operand in enumerate(operands):
      self.graph.add_edge(pydot.Edge(str(operand.__hash__()),node_id+":f%d" % port))
594
595
596
def createArtist(node,dep=None,invdep=None,graph=None,artists=None,**kwargs):
  """Return the artist responsible for drawing ``node``.

  SX nodes are dispatched as follows:
    - scalar leaf      -> SXLeafArtist
    - scalar non-leaf  -> SXNonLeafArtist
    - non-scalar       -> SXArtist
  MX nodes are dispatched on their operation kind, falling back to
  MXGenericArtist.

  ``dep``, ``invdep`` and ``artists`` are handed to the artist as-is (fresh
  dicts when omitted, instead of shared mutable defaults).  ``graph`` and
  ``kwargs`` are forwarded unchanged.

  Raises Exception for a node that is neither SX nor MX.  The message
  previously referenced an undefined name and raised NameError instead.
  """
  if dep is None:
    dep = {}
  if invdep is None:
    invdep = {}
  if artists is None:
    artists = {}

  if isinstance(node,SX):
    if not node.is_scalar(True):
      artist_class = SXArtist
    elif is_leaf(node):
      artist_class = SXLeafArtist
    else:
      artist_class = SXNonLeafArtist
  elif isinstance(node,MX):
    # The order of the predicates matters: the first match wins.
    if node.is_symbolic():
      artist_class = MXSymbolicArtist
    elif node.is_binary() or node.is_unary():
      artist_class = MXOperationArtist
    elif node.is_constant():
      artist_class = MXConstantArtist
    elif node.is_call():
      artist_class = MXEvaluationArtist
    elif node.is_output():
      artist_class = MXEvaluationOutputArtist
    elif node.is_norm():
      artist_class = MXNormArtist
    elif node.is_op(C.OP_GETNONZEROS):
      artist_class = MXGetNonzerosArtist
    elif node.is_op(C.OP_SETNONZEROS):
      artist_class = MXSetNonzerosArtist
    elif node.is_op(C.OP_ADDNONZEROS):
      artist_class = MXAddNonzerosArtist
    else:
      artist_class = MXGenericArtist
  else:
    raise Exception("Cannot create artist for %s" % str(type(node)))

  return artist_class(node,dep=dep,invdep=invdep,graph=graph,artists=artists,**kwargs)
631
def dotgraph(s,direction="BT",**kwargs):
  """
  Creates and returns a pydot graph structure that represents an SX or MX
  expression.

  direction   one of "BT", "LR", "TB", "RL"

  Extra keyword arguments are forwarded to every artist as drawing options
  (e.g. ``max_numel``, ``max_nnz``).  As a side effect, the generated DOT
  source is written to ``source.dot`` in the current working directory.
  """
  # While the graph is built, scalar SX nodes are identified by their
  # element hash and non-scalar SX by 0.  Take the backup before ``try`` so
  # the ``finally`` clause can always restore it.
  SX__hash__backup = SX.__hash__

  def getHashSX(e):
    if e.is_scalar(True):
      try:
        return e.element_hash()
      except Exception:
        return SX__hash__backup(e)
    else:
      return 0

  SX.__hash__ = getHashSX
  try:
    # Get the dependencies and inverse dependencies in a dict
    dep, invdep = dependencyGraph(s,{},{})

    allnodes = set(dep).union(invdep)

    graph = pydot.Dot('G', graph_type='digraph',rankdir=direction)

    # Every artist receives the same ``artists`` dict; it is filled here
    # before any ``draw`` so artists can look up each other.
    artists = {}
    for node in allnodes:
      artists[node] = createArtist(node,dep=dep,invdep=invdep,graph=graph,artists=artists,**kwargs)

    for artist in artists.values():
      if artist is None: continue
      artist.draw()

    # NOTE(review): looks like a leftover debug dump; kept for compatibility.
    with open('source.dot','w') as dump:
      dump.write(graph.to_string())
  finally:
    SX.__hash__ = SX__hash__backup
  return graph
678
679
def dotsave(s,format='ps',filename="temp",direction="RL",**kwargs):
  """
  Make a drawing of an SX and save it.

  format can be one of:
    dot canon cmap cmapx cmapx_np dia fig gd gd2 gif hpgl imap imap_np
    ismap jpe jpeg jpg mif mp pcl pdf pic plain plain-ext png ps ps2 raw
    svg svgz vml vmlz vrml vtx wbmp xdot xlib

  direction   one of "BT", "LR", "TB", "RL"

  The graph built by ``dotgraph`` (which also receives the extra keyword
  arguments) is written with its ``write_<format>`` method to ``filename``.
  Raises Exception listing the available formats when no such method exists.
  """
  graph = dotgraph(s,direction=direction,**kwargs)
  writer_name = 'write_' + format
  if not hasattr(graph, writer_name):
    available = ['dot'] + [attr[6:] for attr in dir(graph) if attr.startswith("write_")]
    raise Exception("Unknown format '%s'. Please pick one of the following:\n" % format + " ".join(available))
  getattr(graph, writer_name)(filename)
703
def dotdraw(s,direction="RL",**kwargs):
  """
  Make a drawing of an SX and display it.

  direction   one of "BT", "LR", "TB", "RL"

  The output depends on the environment:
    - pylab cannot be imported: the graph is written to ``temp.ps`` instead.
    - ``show`` is a pyreport ``PylabShow``: the graph is written as a PDF
      figure.
    - otherwise: the graph is rendered to ``_temp.png`` and shown with
      matplotlib.
  Extra keyword arguments are forwarded to ``dotgraph``.
  """

  try:  # Check if we have pylab
    from pylab import imread, imshow,show,figure, axes
  except Exception:
    # Any failure to import pylab (missing package, broken backend) falls
    # back to a file.  Unlike the former bare ``except:``, this lets
    # KeyboardInterrupt/SystemExit propagate.
    print("casadi.tools.graph.dotdraw: no pylab detected, will not show drawing on screen.")
    dotgraph(s,direction=direction,**kwargs).write_ps("temp.ps")
    return

  if show.__class__.__name__=='PylabShow':
    # catch pyreport case, so we have true vector graphics
    figure_name = '%s%d.%s' % ( show.basename, len(show.figure_list), show.figure_extension )
    show.figure_list += (figure_name, )
    dotgraph(s,direction=direction,**kwargs).write_pdf(figure_name)
    print("Here goes figure %s (dotdraw)" % figure_name)
  else:
    # Matplotlib does not allow to display vector graphics on screen,
    # so we fall back to png
    temp="_temp.png"
    dotgraph(s,direction=direction,**kwargs).write_png(temp)
    im = imread(temp)
    figure()
    ax = axes([0,0,1,1], frameon=False)
    ax.set_axis_off()
    imshow(im)
    show()
736