#
#     This file is part of CasADi.
#
#     CasADi -- A symbolic framework for dynamic optimization.
#     Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
#                             K.U. Leuven. All rights reserved.
#     Copyright (C) 2011-2014 Greg Horn
#
#     CasADi is free software; you can redistribute it and/or
#     modify it under the terms of the GNU Lesser General Public
#     License as published by the Free Software Foundation; either
#     version 3 of the License, or (at your option) any later version.
#
#     CasADi is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#     Lesser General Public License for more details.
#
#     You should have received a copy of the GNU Lesser General Public
#     License along with CasADi; if not, write to the Free Software
#     Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
#
#
"""Draw CasADi SX/MX expression graphs with pydot (Graphviz).

The public entry points are ``dotgraph``, ``dotsave`` and ``dotdraw``.  Every
node of the expression graph is rendered by an "artist" object selected by
``createArtist``.
"""
from casadi import SX, MX, print_operator
import casadi as C

try:
    from pydot import pydot
except ImportError:
    try:
        import pydot
    except ImportError:
        # ImportError is a subclass of Exception, so callers that used to
        # catch the generic Exception raised here keep working.
        raise ImportError("To use the functionality of casadi.tools.graph, you need to have pydot Installed. Try `easy_install pydot`.")

#import ipdb


def _uid(node):
    """Return the string used as the graphviz node name of *node*.

    ``__hash__`` is called directly (rather than ``hash()``) so that the
    identifier is exactly the value returned by the method.
    """
    return str(node.__hash__())


def _htmlRows(s, sp, cell):
    """Return the ``<TR>...</TR>`` rows of an HTML table laid out like *s*.

    For every (i, j) in ``s.size1() x s.size2()`` the nonzero index
    ``k = sp.get_nz(i, j)`` is looked up.  ``k == -1`` is drawn as a "."
    cell; otherwise the string returned by ``cell(i, j, k)`` is used.
    """
    parts = []
    for i in range(s.size1()):
        parts.append("<TR>")
        for j in range(s.size2()):
            k = sp.get_nz(i, j)
            parts.append("<TD>.</TD>" if k == -1 else cell(i, j, k))
        parts.append("</TR>")
    return "".join(parts)


def hashcompare(self, other):
    """Three-way comparison (-1, 0 or 1) of two objects by their hash."""
    a, b = hash(self), hash(other)
    # ``cmp`` does not exist in Python 3; this is the documented replacement.
    return (a > b) - (a < b)


def equality(self, other):
    """Return True when both objects have the same hash."""
    return hash(self) == hash(other)


def getDeps(s):
    """Return the list of dependencies ``s.dep(k)`` for k in ``range(s.n_dep())``.

    Objects without an ``n_dep`` attribute yield an empty list.

    NOTE(review): as a side effect the *class* of every dependency gets its
    ``__cmp__``/``__eq__`` replaced by the hash-based helpers above, and this
    is never undone.  Presumably this is what lets the nodes be used as
    dict/set keys -- confirm before relying on ``==`` of these classes
    elsewhere.
    """
    deps = []
    if not hasattr(s, 'n_dep'):
        return deps
    for k in range(s.n_dep()):
        d = s.dep(k)
        d.__class__.__cmp__ = hashcompare
        d.__class__.__eq__ = equality
        deps.append(d)
    return deps


def addDependency(master, slave, dep=None, invdep=None):
    """Record that *master* depends on *slave*.

    ``dep`` maps a node to the set of nodes it depends on; ``invdep`` maps a
    node to the set of nodes depending on it.  Both are updated in place.
    When omitted, a fresh (discarded) dict is used.
    """
    if dep is None:
        dep = {}
    if invdep is None:
        invdep = {}
    #print master.__hash__() , " => ", slave.__hash__(), "   ", master , " => ", slave
    dep.setdefault(master, set()).add(slave)
    invdep.setdefault(slave, set()).add(master)


def addDependencies(master, slaves, dep=None, invdep=None):
    """Record all *slaves* as dependencies of *master*, then expand each slave.

    ``dep`` and ``invdep`` are updated in place (see ``addDependency``).
    """
    if dep is None:
        dep = {}
    if invdep is None:
        invdep = {}
    for slave in slaves:
        addDependency(master, slave, dep=dep, invdep=invdep)
    for slave in slaves:
        # A slave that is already a key of ``dep`` has been expanded before;
        # expanding it again adds nothing and makes shared subexpressions
        # exponentially expensive.
        if slave not in dep:
            dependencyGraph(slave, dep=dep, invdep=invdep)


def is_leaf(s):
    """Return ``s.is_leaf()``."""
    return s.is_leaf()
    #return s.is_scalar(True) and (s.is_constant() or s.is_symbolic())


def dependencyGraph(s, dep=None, invdep=None):
    """Collect the dependency graph rooted at *s*.

    Returns the tuple ``(dep, invdep)``.  Dicts passed in are filled in place
    and returned; when omitted, new dicts are created for this call.

    * scalar SX: non-leaf nodes contribute their ``getDeps``;
    * non-scalar SX: depends on each of its ``nonzeros()``;
    * MX: contributes its ``getDeps``.
    """
    if dep is None:
        dep = {}
    if invdep is None:
        invdep = {}
    if isinstance(s, SX):
        if s.is_scalar(True):
            if not is_leaf(s):
                addDependencies(s, getDeps(s), dep=dep, invdep=invdep)
        else:
            addDependencies(s, s.nonzeros(), dep=dep, invdep=invdep)
    elif isinstance(s, MX):
        addDependencies(s, getDeps(s), dep=dep, invdep=invdep)
    return (dep, invdep)


class DotArtist:
    """Base class of all artists: holds the node and the shared graph state."""

    # Fill colour used for sparsity nodes.
    sparsitycol = "#eeeeee"

    def __init__(self, s, dep=None, invdep=None, graph=None, artists=None, **kwargs):
        """Store the node *s*, the dependency dicts, the pydot graph, the
        node -> artist mapping and any extra drawing options (*kwargs*)."""
        self.s = s
        self.dep = {} if dep is None else dep
        self.invdep = {} if invdep is None else invdep
        self.graph = graph
        self.artists = {} if artists is None else artists
        self.kwargs = kwargs

    def hasPorts(self):
        """Return False; subclasses that draw per-nonzero ports override this."""
        return False

    def drawSparsity(self, s, id=None, depid=None, graph=None, nzlabels=None):
        """Add a node showing the sparsity of *s* and an edge ``depid -> id``.

        id       -- node name; defaults to the hash of *s*
        depid    -- name of the source node; defaults to the hash of ``s.dep(0)``
        graph    -- graph to draw into; defaults to ``self.graph``
        nzlabels -- labels for the nonzero cells; defaults to "0", "1", ...

        When ``s.nnz() == s.numel()`` a plain "rows x cols" rectangle is
        drawn, otherwise an HTML table with one cell per entry.
        """
        if id is None:
            id = _uid(s)
        if depid is None:
            depid = _uid(s.dep(0))
        if graph is None:
            graph = self.graph
        sp = s.sparsity()
        if nzlabels is None:
            nzlabels = [str(i) for i in range(sp.nnz())]
        if s.nnz() == s.numel():
            graph.add_node(pydot.Node(id, label="%d x %d" % (s.size1(), s.size2()), shape='rectangle', color=self.sparsitycol, style="filled"))
        else:
            # Labels are consumed in visiting order (row by row), not by
            # nonzero index k.
            nzlabelcounter = 0
            label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">'
            label += "<TR><TD COLSPAN='%d'><font color='#666666'>%s</font></TD></TR>" % (s.size2(), s.dim())
            for i in range(s.size1()):
                label += "<TR>"
                for j in range(s.size2()):
                    k = sp.get_nz(i, j)
                    if k == -1:
                        label += "<TD>.</TD>"
                    else:
                        label += "<TD PORT='f%d' BGCOLOR='%s'>%s</TD>" % (k, self.sparsitycol, nzlabels[nzlabelcounter])
                        nzlabelcounter += 1
                label += "</TR>"
            label += "</TABLE>>"
            graph.add_node(pydot.Node(id, label=label, shape='plaintext'))
        graph.add_edge(pydot.Edge(depid, id))


class MXSymbolicArtist(DotArtist):
    """Draws a symbolic MX as a named box (1x1 dense) or an HTML table."""

    def hasPorts(self):
        return True

    def draw(self):
        s = self.s
        graph = self.graph
        sp = s.sparsity()
        col = "#990000"
        if s.nnz() == s.numel() and s.nnz() == 1:
            graph.add_node(pydot.Node(_uid(s) + ":f0", label=s.name(), shape='rectangle', color=col))
        else:
            # The Matrix grid is represented by a html table with 'ports'
            label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % col
            label += "<TR><TD COLSPAN='%d'>%s: <font color='#666666'>%s</font></TD></TR>" % (s.size2(), s.name(), s.dim())
            label += _htmlRows(s, sp, lambda i, j, k: "<TD PORT='f%d' BGCOLOR='#eeeeee'> <font color='#666666'>(%d,%d | %d)</font> </TD>" % (k, i, j, k))
            label += "</TABLE>>"
            graph.add_node(pydot.Node(_uid(s), label=label, shape='plaintext'))

# class MXMappingArtist(DotArtist):
#   def draw(self):
#     s = self.s
#     graph = self.graph
#     sp = s.sparsity()
#     row = sp.row()


#     # Note: due to Mapping restructuring, this is no longer efficient code
#     deps = getDeps(s)

#     depind = s.depInd()
#     nzmap = sum([s.mapping(i) for i in range(len(deps))])

#     for k,d in enumerate(deps):
#       candidates = map(hash,filter(lambda i: i.isMapping(),self.invdep[d]))
#       candidates.sort()
#       if candidates[0] == hash(s):
#         graph.add_edge(pydot.Edge(str(d.__hash__()),"mapinput" + str(d.__hash__())))

#     graph = pydot.Cluster('clustertest' + str(s.__hash__()), rank='max', label='Mapping')
#     self.graph.add_subgraph(graph)



#     colors = ['#eeeecc','#ccccee','#cceeee','#eeeecc','#eeccee','#cceecc']

#     for k,d in enumerate(deps):
#       spd = d.sparsity()
#       #ipdb.set_trace()
#       # The Matrix grid is represented by a html table with 'ports'
#       candidates = map(hash,filter(lambda i: i.isMapping(),self.invdep[d]))
#       candidates.sort()
#       if candidates[0] == hash(s):
#         label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="#0000aa">'
#         if not(d.numel()==1 and d.numel()==d.nnz()):
#           label+="<TR><TD COLSPAN='%d' BGCOLOR='#dddddd'><font>%s</font></TD></TR>" % (d.size2(), d.dim())
#         for i in range(d.size1()):
#           label+="<TR>"
#           for j in range(d.size2()):
#             kk = spd.get_nz(i,j)
#             if kk==-1:
#               label+="<TD>.</TD>"
#             else:
#               label+="<TD PORT='f%d' BGCOLOR='%s'> <font color='#666666'>%d</font> </TD>" % (kk,colors[k],kk)
#           label+="</TR>"
#         label+="</TABLE>>"
#         graph.add_node(pydot.Node("mapinput" + str(d.__hash__()),label=label,shape='plaintext'))
#       graph.add_edge(pydot.Edge("mapinput" + str(d.__hash__()),str(s.__hash__())))

#     # The Matrix grid is represented by a html table with 'ports'
#     label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">'
#     if not(s.numel()==1 and s.numel()==s.nnz()):
#       label+="<TR><TD COLSPAN='%d'><font color='#666666'>%s</font></TD></TR>" % (s.size2(), s.dim())
#     for i in range(s.size1()):
#       label+="<TR>"
#       for j in range(s.size2()):
#         k = sp.get_nz(i,j)
#         if k==-1:
#           label+="<TD>.</TD>"
#         else:
#           label+="<TD PORT='f%d' BGCOLOR='%s'> <font color='#666666'>%d</font> </TD>" % (k,colors[depind[k]],nzmap[k])
#       label+="</TR>"
#     label+="</TABLE>>"
#     graph.add_node(pydot.Node(str(self.s.__hash__()),label=label,shape='plaintext'))


class MXEvaluationArtist(DotArtist):
    """Draws a function call as a cluster with an input and an output record."""

    def draw(self):
        s = self.s
        deps = getDeps(s)
        f = s.which_function()

        # Edges from every argument into the matching port of the input record.
        for k, d in enumerate(deps):
            self.graph.add_edge(pydot.Edge(_uid(d), "funinput" + _uid(s) + ":f%d" % k, rankdir="LR"))

        cluster = pydot.Cluster(_uid(s), rank='max', label='Function:\n %s' % f.name())
        self.graph.add_subgraph(cluster)

        inputs = (" %d inputs: |" % f.n_in()) + " | ".join("<f%d> %d" % (i, i) for i in range(f.n_in()))
        cluster.add_node(pydot.Node("funinput" + _uid(s), label=inputs, shape='Mrecord'))

        outputs = (" %d outputs: |" % f.n_out()) + " | ".join("<f%d> %d" % (i, i) for i in range(f.n_out()))
        cluster.add_node(pydot.Node(_uid(s), label=outputs, shape='Mrecord'))


class MXConstantArtist(DotArtist):
    """Draws a constant MX as a value box (1x1 dense) or an HTML table.

    The optional ``max_numel`` keyword suppresses the table body when
    ``s.numel()`` is not below it.
    """

    def hasPorts(self):
        return True

    def draw(self):
        s = self.s
        graph = self.graph
        sp = s.sparsity()
        M = s.to_DM()
        col = "#009900"
        if s.nnz() == s.numel() and s.nnz() == 1:
            graph.add_node(pydot.Node(_uid(s) + ":f0", label=str(M[0, 0]), shape='rectangle', color=col))
        else:
            # The Matrix grid is represented by a html table with 'ports'
            label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % col
            label += "<TR><TD COLSPAN='%d'><font color='#666666'>%s</font></TD></TR>" % (s.size2(), s.dim())
            if "max_numel" not in self.kwargs or s.numel() < self.kwargs["max_numel"]:
                label += _htmlRows(s, sp, lambda i, j, k: "<TD PORT='f%d' BGCOLOR='#eeeeee'> %s </TD>" % (k, M[i, j]))
            label += "</TABLE>>"
            graph.add_node(pydot.Node(_uid(s), label=label, shape='plaintext'))


class MXGenericArtist(DotArtist):
    """Fallback artist for MX nodes without a dedicated artist."""

    def draw(self):
        k = self.s
        graph = self.graph
        dep = getDeps(k)

        # Only draw a separate sparsity node when some dependency has a
        # sparsity different from the node itself.
        show_sp = not all(d.sparsity() == k.sparsity() for d in dep)

        if show_sp:
            op = "op"
            self.drawSparsity(k, depid=op + _uid(k))
        else:
            op = ""

        if len(dep) > 1:
            # Non-commutative operators are represented by 'record' shapes.
            # The dependencies have different 'ports' where arrows should arrive.
            s = print_operator(self.s, ["| <f%d> | " % i for i in range(len(dep))])
            if s.startswith("(|") and s.endswith("|)"):
                s = s[2:-2]

            graph.add_node(pydot.Node(op + _uid(k), label=s, shape='Mrecord'))
            for i, n in enumerate(dep):
                graph.add_edge(pydot.Edge(_uid(n), op + _uid(k) + ":f%d" % i))
        else:
            s = print_operator(k, ["."])
            graph.add_node(pydot.Node(op + _uid(k), label=s, shape='oval'))
            for n in dep:
                graph.add_edge(pydot.Edge(_uid(n), op + _uid(k)))


class MXGetNonzerosArtist(DotArtist):
    """Draws a get-nonzeros node, labelling cells with ``s.mapping()`` entries.

    The optional ``max_nnz`` keyword suppresses the table body when
    ``s.nnz()`` is not below it.
    """

    def draw(self):
        s = self.s
        graph = self.graph
        n = getDeps(s)[0]

        is_dense_scalar = s.nnz() == s.numel() and s.nnz() == 1

        if not is_dense_scalar:
            op = "op"
            self.drawSparsity(s, depid=op + _uid(s))
        else:
            op = ""

        sp = s.sparsity()
        M = s.mapping()
        col = "#333333"
        if is_dense_scalar:
            graph.add_node(pydot.Node(op + _uid(s) + ":f0", label="[%s]" % str(M[0, 0]), shape='rectangle', style="filled", fillcolor='#eeeeff'))
        else:
            # The Matrix grid is represented by a html table with 'ports'
            label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % col
            label += "<TR><TD COLSPAN='%d' PORT='entry'>getNonzeros</TD></TR>" % (s.size2())
            if "max_nnz" not in self.kwargs or s.nnz() < self.kwargs["max_nnz"]:
                label += _htmlRows(s, sp, lambda i, j, k: "<TD PORT='f%d' BGCOLOR='#eeeeff'> %s </TD>" % (k, M[i, j]))
            label += "</TABLE>>"
            graph.add_node(pydot.Node(op + _uid(s), label=label, shape='plaintext'))
        graph.add_edge(pydot.Edge(_uid(n), op + _uid(s)))


class _MXAssignNonzerosArtist(DotArtist):
    """Shared drawing code of the set-nonzeros and add-nonzeros artists.

    Subclasses provide the table title and choose which sparsity pattern is
    queried for every cell.
    """

    # Text shown in the header row of the table.
    title = None

    def lookupSparsity(self, entry, target):
        """Return the sparsity whose ``get_nz(i, j)`` is matched against the mapping."""
        raise NotImplementedError

    def draw(self):
        s = self.s
        graph = self.graph
        deps = getDeps(s)
        entry = deps[0]
        target = deps[1]

        show_sp = not all(d.sparsity() == s.sparsity() for d in deps)

        if show_sp:
            op = "op"
            self.drawSparsity(s, depid=op + _uid(s))
        else:
            op = ""

        sp = self.lookupSparsity(entry, target)
        M = list(s.mapping())
        Mk = 0  # running position in the mapping M
        col = "#333333"

        # The Matrix grid is represented by a html table with 'ports'
        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" COLOR="%s">' % col
        label += "<TR><TD COLSPAN='%d' PORT='entry'>%s</TD></TR>" % (s.size2(), self.title)
        for i in range(s.size1()):
            label += "<TR>"
            for j in range(s.size2()):
                k = sp.get_nz(i, j)
                if k == -1 or Mk >= len(M) or k != M[Mk]:
                    label += "<TD>.</TD>"
                    # Step over a -1 mapping entry when the cell itself is a
                    # structural nonzero.
                    if Mk < len(M) - 1 and M[Mk] == -1 and k != -1:
                        Mk += 1
                else:
                    label += "<TD PORT='f%d' BGCOLOR='#eeeeff'> %s </TD>" % (Mk, Mk)
                    Mk += 1
            label += "</TR>"
        label += "</TABLE>>"
        graph.add_node(pydot.Node(op + _uid(s), label=label, shape='plaintext'))
        graph.add_edge(pydot.Edge(_uid(entry), op + _uid(s) + ':entry'))
        graph.add_edge(pydot.Edge(_uid(target), op + _uid(s)))


class MXSetNonzerosArtist(_MXAssignNonzerosArtist):
    """Draws a set-nonzeros node; cells are looked up in the sparsity of dependency 0."""

    title = "setNonzeros"

    def lookupSparsity(self, entry, target):
        return entry.sparsity()


class MXAddNonzerosArtist(_MXAssignNonzerosArtist):
    """Draws an add-nonzeros node; cells are looked up in the sparsity of dependency 1."""

    title = "addNonzeros"

    def lookupSparsity(self, entry, target):
        return target.sparsity()


class MXOperationArtist(DotArtist):
    """Draws a unary or binary MX operation."""

    def draw(self):
        k = self.s
        graph = self.graph
        dep = getDeps(k)

        # The sparsity node is skipped only when all operands share the
        # sparsity of the result.
        show_sp = True

        if k.is_unary() and dep[0].sparsity() == k.sparsity():
            show_sp = False
        if k.is_binary() and dep[0].sparsity() == k.sparsity() and dep[1].sparsity() == k.sparsity():
            show_sp = False

        if show_sp:
            op = "op"
            self.drawSparsity(k, depid=op + _uid(k))
        else:
            op = ""

        if not k.is_commutative():
            # Non-commutative operators are represented by 'record' shapes.
            # The dependencies have different 'ports' where arrows should arrive.
            s = print_operator(self.s, ["| <f0> | ", " | <f1> |"])
            if s.startswith("(|") and s.endswith("|)"):
                s = s[2:-2]

            graph.add_node(pydot.Node(op + _uid(k), label=s, shape='Mrecord'))
            for i, n in enumerate(dep):
                graph.add_edge(pydot.Edge(_uid(n), op + _uid(k) + ":f%d" % i))
        else:
            # Commutative operators can be represented more compactly as 'oval' shapes.
            s = print_operator(k, [".", "."])
            if s.startswith("(.") and s.endswith(".)"):
                s = s[2:-2]
            if s.startswith("(") and s.endswith(")"):
                s = s[1:-1]
            graph.add_node(pydot.Node(op + _uid(k), label=s, shape='oval'))
            for n in dep:
                graph.add_edge(pydot.Edge(_uid(n), op + _uid(k)))


class MXIfTestArtist(DotArtist):
    """Draws an if-test node as a two-port record."""

    def draw(self):
        k = self.s
        graph = self.graph
        dep = getDeps(k)

        s = "<f0> ? | <f1> true"

        graph.add_node(pydot.Node(_uid(k), label=s, shape='Mrecord'))
        for i, n in enumerate(dep):
            graph.add_edge(pydot.Edge(_uid(n), _uid(k) + ":f%d" % i))


class MXDensificationArtist(DotArtist):
    """Draws a "densify(.)" oval fed by the first dependency."""

    def draw(self):
        k = self.s
        dep = getDeps(k)

        self.graph.add_node(pydot.Node(_uid(k), label="densify(.)", shape='oval'))
        self.graph.add_edge(pydot.Edge(_uid(dep[0]), _uid(k)))


class MXNormArtist(DotArtist):
    """Draws a norm node as an oval fed by the first dependency."""

    def draw(self):
        k = self.s
        dep = getDeps(k)
        s = print_operator(k, [".", "."])
        self.graph.add_node(pydot.Node(_uid(k), label=s, shape='oval'))
        self.graph.add_edge(pydot.Edge(_uid(dep[0]), _uid(k)))


class MXEvaluationOutputArtist(DotArtist):
    """Draws one output of a function call as a sparsity node attached to
    port ``which_output()`` of the call node."""

    def draw(self):
        k = self.s

        self.drawSparsity(k, depid=str(hash(k.dep(0))) + ":f%d" % k.which_output())


class MXMultiplicationArtist(DotArtist):
    """Draws a matrix multiplication as a two-port record."""

    def draw(self):
        k = self.s
        graph = self.graph
        dep = getDeps(k)

        # Non-commutative operators are represented by 'record' shapes.
        # The dependencies have different 'ports' where arrows should arrive.
        s = "mul(| <f0> | , | <f1> | )"

        graph.add_node(pydot.Node(_uid(k), label=s, shape='Mrecord'))
        for i, n in enumerate(dep):
            graph.add_edge(pydot.Edge(_uid(n), _uid(k) + ":f%d" % i))


class SXArtist(DotArtist):
    """Draws a non-scalar SX as an HTML table with one cell per entry."""

    def draw(self):
        s = self.s
        graph = self.graph
        sp = s.sparsity()

        def cell(i, j, k):
            sx = s.nz[k]
            if self.shouldEmbed(sx):
                # The element is printed inside the cell; SXLeafArtist then
                # skips drawing a node for it.
                return "<TD BGCOLOR='#eeeeee'>%s</TD>" % str(sx)
            graph.add_edge(pydot.Edge(_uid(sx), "%s:f%d" % (_uid(s), k)))
            return "<TD PORT='f%d' BGCOLOR='#eeeeee'> <font color='#666666'>(%d,%d|%d)</font> </TD>" % (k, i, j, k)

        # The Matrix grid is represented by a html table with 'ports'
        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">'
        label += _htmlRows(s, sp, cell)
        label += "</TABLE>>"
        graph.add_node(pydot.Node(_uid(s), label=label, shape='plaintext'))

    def shouldEmbed(self, sx):
        """Return True when *sx* is a leaf with exactly one node depending on it."""
        return len(self.invdep[sx]) == 1 and sx.is_leaf()


class SXLeafArtist(DotArtist):
    """Draws a scalar SX leaf as a box, unless its only parent embeds it."""

    def draw(self):
        s = self.s
        if len(self.invdep[s]) == 1:
            master = list(self.invdep[s])[0]
            if hasattr(self.artists[master], 'shouldEmbed'):
                if self.artists[master].shouldEmbed(s):
                    return
        style = "solid"  # Symbolic nodes are represented box'es
        if s.is_constant():
            style = "bold"  # Constants are represented by bold box'es
        self.graph.add_node(pydot.Node(_uid(s), label=str(s), shape="box", style=style))


class SXNonLeafArtist(DotArtist):
    """Draws a scalar SX operation."""

    def draw(self):
        k = self.s
        graph = self.graph
        dep = getDeps(k)
        if not k.is_commutative():
            # Non-commutative operators are represented by 'record' shapes.
            # The dependencies have different 'ports' where arrows should arrive.
            if len(dep) == 2:
                s = print_operator(self.s, ["| <f0> | ", " | <f1> |"])
            else:
                s = print_operator(self.s, ["| <f0> | "])
            if s.startswith("(|") and s.endswith("|)"):
                s = s[2:-2]

            graph.add_node(pydot.Node(_uid(k), label=s, shape='Mrecord'))
            for i, n in enumerate(dep):
                graph.add_edge(pydot.Edge(_uid(n), _uid(k) + ":f%d" % i))
        else:
            # Commutative operators can be represented more compactly as 'oval' shapes.
            s = print_operator(k, [".", "."])
            if s.startswith("(.") and s.endswith(".)"):
                s = s[2:-2]
            if s.startswith("(") and s.endswith(")"):
                s = s[1:-1]
            graph.add_node(pydot.Node(_uid(k), label=s, shape='oval'))
            for n in dep:
                graph.add_edge(pydot.Edge(_uid(n), _uid(k)))


def createArtist(node, dep=None, invdep=None, graph=None, artists=None, **kwargs):
    """Return the artist instance responsible for drawing *node*.

    The remaining arguments are forwarded unchanged to the artist
    constructor.  Raises ``Exception`` when *node* is neither SX nor MX.
    """
    if isinstance(node, SX):
        if node.is_scalar(True):
            artist_cls = SXLeafArtist if is_leaf(node) else SXNonLeafArtist
        else:
            artist_cls = SXArtist
    elif isinstance(node, MX):
        # The order of these tests matters: the first match wins.
        if node.is_symbolic():
            artist_cls = MXSymbolicArtist
        elif node.is_binary() or node.is_unary():
            artist_cls = MXOperationArtist
        elif node.is_constant():
            artist_cls = MXConstantArtist
        elif node.is_call():
            artist_cls = MXEvaluationArtist
        elif node.is_output():
            artist_cls = MXEvaluationOutputArtist
        elif node.is_norm():
            artist_cls = MXNormArtist
        elif node.is_op(C.OP_GETNONZEROS):
            artist_cls = MXGetNonzerosArtist
        elif node.is_op(C.OP_SETNONZEROS):
            artist_cls = MXSetNonzerosArtist
        elif node.is_op(C.OP_ADDNONZEROS):
            artist_cls = MXAddNonzerosArtist
        else:
            artist_cls = MXGenericArtist
    else:
        # The original formatted ``type(s)`` here, but ``s`` is undefined in
        # this scope and produced a NameError instead of this message.
        raise Exception("Cannot create artist for %s" % str(type(node)))
    return artist_cls(node, dep=dep, invdep=invdep, graph=graph, artists=artists, **kwargs)


def dotgraph(s, direction="BT", **kwargs):
    """
    Creates and returns a pydot graph structure that represents an SX.

    direction   one of "BT", "LR", "TB", "RL"

    Extra keyword arguments are forwarded to the artists.  While the graph is
    built, ``SX.__hash__`` is temporarily replaced; it is restored afterwards
    even on error.  The dot source is also written to ``source.dot`` in the
    current working directory.
    """
    SX__hash__backup = SX.__hash__

    def getHashSX(e):
        # Scalars hash by element; everything else shares the hash 0.
        if e.is_scalar(True):
            try:
                return e.element_hash()
            except Exception:
                return SX__hash__backup(e)
        else:
            return 0

    SX.__hash__ = getHashSX
    try:
        # Get the dependencies and inverse dependencies in a dict
        dep, invdep = dependencyGraph(s, {}, {})

        allnodes = set(dep.keys()).union(set(invdep.keys()))

        #print "a", set(dep.keys()), [i.__hash__() for i in dep.keys()]
        #print "b", set(invdep.keys()), [i.__hash__() for i in invdep.keys()]
        #print "allnodes", allnodes, [i.__hash__() for i in allnodes]

        #return None

        artists = {}

        graph = pydot.Dot('G', graph_type='digraph', rankdir=direction)

        # All artists must exist before drawing starts: SXLeafArtist looks up
        # the artist of its parent.
        for node in allnodes:
            artists[node] = createArtist(node, dep=dep, invdep=invdep, graph=graph, artists=artists, **kwargs)

        for artist in artists.values():
            if artist is None:
                continue
            artist.draw()

        with open('source.dot', 'w') as dotfile:
            dotfile.write(graph.to_string())
    finally:
        SX.__hash__ = SX__hash__backup
    return graph


def dotsave(s, format='ps', filename="temp", direction="RL", **kwargs):
    """
    Make a drawing of an SX and save it.

    format can be one of:
      dot canon cmap cmapx cmapx_np dia dot fig gd gd2 gif hpgl imap imap_np
      ismap jpe jpeg jpg mif mp pcl pdf pic plain plain-ext png ps ps2 raw
      svg svgz vml vmlz vrml vtx wbmp xdot xlib

    direction   one of "BT", "LR", "TB", "RL"

    Raises ``Exception`` listing the available formats when the graph has no
    ``write_<format>`` method.
    """
    g = dotgraph(s, direction=direction, **kwargs)
    writer = getattr(g, 'write_' + format, None)
    if writer is None:
        msg = "Unknown format '%s'. Please pick one of the following:\n" % format
        available = ['dot'] + [n[6:] for n in dir(g) if n.startswith("write_")]
        msg += " ".join(available)
        raise Exception(msg)
    writer(filename)


def dotdraw(s, direction="RL", **kwargs):
    """
    Make a drawing of an SX and display it.

    direction   one of "BT", "LR", "TB", "RL"

    Without pylab the drawing is written to "temp.ps" instead of being shown.
    """

    try:  # Check if we have pylab
        from pylab import imread, imshow, show, figure, axes
    except Exception:
        # We don't have pylab, so just write out to file
        print("casadi.tools.graph.dotdraw: no pylab detected, will not show drawing on screen.")
        dotgraph(s, direction=direction, **kwargs).write_ps("temp.ps")
        return

    if hasattr(show, '__class__') and show.__class__.__name__ == 'PylabShow':
        # catch pyreport case, so we have true vector graphics
        figure_name = '%s%d.%s' % (show.basename, len(show.figure_list), show.figure_extension)
        show.figure_list += (figure_name, )
        dotgraph(s, direction=direction, **kwargs).write_pdf(figure_name)
        print("Here goes figure %s (dotdraw)" % figure_name)
    else:
        # Matplotlib does not allow to display vector graphics on screen,
        # so we fall back to png
        temp = "_temp.png"
        dotgraph(s, direction=direction, **kwargs).write_png(temp)
        im = imread(temp)
        figure()
        ax = axes([0, 0, 1, 1], frameon=False)
        ax.set_axis_off()
        imshow(im)
        show()