1# (c) Copyright by Pierre-Henri Wuillemin, UPMC, 2017 2# (pierre-henri.wuillemin@lip6.fr) 3# Permission to use, copy, modify, and distribute this 4# software and its documentation for any purpose and 5# without fee or royalty is hereby granted, provided 6# that the above copyright notice appear in all copies 7# and that both that copyright notice and this permission 8# notice appear in supporting documentation or portions 9# thereof, including modifications, that you make. 10# THE AUTHOR P.H. WUILLEMIN DISCLAIMS ALL WARRANTIES 11# WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED 12# WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT 13# SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT 14# OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER 15# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER 16# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 17# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE 18# OR PERFORMANCE OF THIS SOFTWARE! 19 20""" 21tools for BN in jupyter notebook 22""" 23 24import time 25import re 26import sys 27 28# fix DeprecationWarning of base64.encodestring() 29try: 30 from base64 import encodebytes 31except ImportError: # 3+ 32 from base64 import encodestring as encodebytes 33 34import io 35import base64 36 37import matplotlib as mpl 38import matplotlib.pyplot as plt 39 40try: 41 from matplotlib_inline.backend_inline import set_matplotlib_formats as set_matplotlib_formats 42except ImportError: # because of python 2.7, matplotlib-inline cannot be part of requirements.txt 43 def set_matplotlib_formats(*args, **kwargs): 44 # dummy version when no matplotlib_inline package 45 if sys.version[0] == "3": 46 print("** pyAgrum** For better visualizations, please install matplotlib-inline.") 47 pass 48 49import numpy as np 50import pydotplus as dot 51 52import IPython.core.display 53import IPython.core.pylabtools 54import IPython.display 55 56import pyAgrum as gum 57from pyAgrum.lib.bn2graph import BN2dot, BNinference2dot 58from pyAgrum.lib.cn2graph import CN2dot, CNinference2dot 59from pyAgrum.lib.id2graph import ID2dot, LIMIDinference2dot 60from pyAgrum.lib.mn2graph import MN2UGdot, MNinference2UGdot 61from pyAgrum.lib.mn2graph import MN2FactorGraphdot, MNinference2FactorGraphdot 62from pyAgrum.lib.bn_vs_bn import GraphicalBNComparator 63from pyAgrum.lib.proba_histogram import proba2histo, probaMinMaxH 64 65from pyAgrum.lib._colors import setDarkTheme,setLightTheme,getBlackInTheme 66from pyAgrum.lib._colors import forDarkTheme,forLightTheme # obsolete since 0.21.0 67 68import pyAgrum.lib._colors as gumcols 69 70 71class FlowLayout(object): 72 """" 73 A class / object to display plots in a horizontal / flow layout below a cell 74 75 based on : https://stackoverflow.com/questions/21754976/ipython-notebook-arrange-plots-horizontally 76 """ 77 78 def __init__(self): 79 self.clear() 80 81 def clear(self): 82 """ 83 clear the flow 84 """ 85 # string buffer for the HTML: initially some CSS; images to be appended 86 self.sHtml = """ 87 <style> 88 .floating-box { 89 display: inline-block; 90 margin: 7px; 91 padding : 3px; 92 border: 2px solid #FFFFFF; 93 valign:middle; 94 background-color: #FDFDFD; 95 } 96 </style> 97 """ 98 99 def _getTitle(self, title): 100 if title == "": 101 return "" 102 return f"<br><center><small><em>{title}</em></small></center>" 103 104 def add_html(self, html, title=""): 105 """ 106 add an html element in the row 107 """ 108 self.sHtml += f'<div class="floating-box">{html}{self._getTitle(title)}</div>' 109 110 def add_separator(self, size=3): 111 """ 112 add a (poor) separation between elements in a row 113 """ 114 self.add_html(" " * size) 115 116 def add_plot(self, oAxes, title=""): 117 """ 118 Add a PNG representation of a Matplotlib Axes object 119 """ 120 Bio = io.BytesIO() # bytes buffer for the plot 121 fig = oAxes.get_figure() 122 fig.canvas.print_png(Bio) # make a png of the plot in the buffer 123 124 # encode the bytes as string using base 64 125 sB64Img = base64.b64encode(Bio.getvalue()).decode() 126 self.sHtml += f'<div class="floating-box"><img src="data:image/png;base64,{sB64Img}\n">{self._getTitle(title)}</div>' 127 plt.close() 128 129 def new_line(self): 130 """ 131 add a breakline (a new row) 132 """ 133 self.sHtml += '<br/>' 134 135 def html(self): 136 """ 137 Returns its content as HTML object 138 """ 139 return IPython.display.HTML(self.sHtml) 140 141 def display(self): 142 """ 143 Display the accumulated HTML 144 """ 145 IPython.display.display(self.html()) 146 self.clear() 147 148 def row(self, *args, captions=None): 149 self.clear() 150 for i, arg in enumerate(args): 151 if captions is None: 152 t = "" 153 else: 154 t = captions[i] 155 156 if hasattr(arg, "get_figure"): 157 self.add_plot(arg, title=t) 158 elif hasattr(arg, "_repr_html_"): 159 self.add_html(arg._repr_html_(), title=t) 160 else: 161 self.add_html(arg, title=t) 162 163 self.display() 164 165 166flow = FlowLayout() 167 168 169def configuration(): 170 """ 171 Display the collection of dependance and versions 172 """ 173 from collections import OrderedDict 174 import sys 175 import os 176 177 packages = OrderedDict() 178 packages["OS"] = "%s [%s]" % (os.name, sys.platform) 179 packages["Python"] = sys.version 180 packages["IPython"] = IPython.__version__ 181 packages["MatPlotLib"] = mpl.__version__ 182 packages["Numpy"] = np.__version__ 183 packages["pyAgrum"] = gum.__version__ 184 185 res = "<table width='100%'><tr><th>Library</th><th>Version</th></tr>" 186 187 for name in packages: 188 res += "<tr><td>%s</td><td>%s</td></tr>" % (name, packages[name]) 189 190 res += "</table><div align='right'><small>%s</small></div>" % time.strftime( 191 '%a %b %d %H:%M:%S %Y %Z' 192 ) 193 194 IPython.display.display(IPython.display.HTML(res)) 195 196 197def __insertLinkedSVGs(mainSvg): 198 re_buggwhitespace = re.compile(r"(<image [^>]*>)") 199 re_images = re.compile(r"(<image [^>]*>)") 200 re_xlink = re.compile(r"xlink:href=\"([^\"]*)") 201 re_viewbox = re.compile(r"(viewBox=\"[^\"]*\")") 202 203 # analyze mainSvg (find the secondary svgs) 204 __fragments = {} 205 for img in re.finditer(re_images, mainSvg): 206 # print(img) 207 secondarySvg = re.findall(re_xlink, img.group(1))[0] 208 content = "" 209 with open(secondarySvg, encoding='utf8') as f: 210 inSvg = False 211 for line in f: 212 if line[0:4] == "<svg": 213 inSvg = True 214 viewBox = re.findall(re_viewbox, line)[0] 215 # print("VIEWBOX {}".format(viewBox)) 216 elif inSvg: 217 content += line 218 __fragments[secondarySvg] = (viewBox, content) 219 220 if len(__fragments) > 0: 221 # replace image tags by svg tags 222 img2svg = re.sub(r"<image ([^>]*)/>", "<svg \g<1>>", mainSvg) 223 224 # insert secondaries into main 225 def ___insertSecondarySvgs(matchObj): 226 vb, code = __fragments[matchObj.group(1)] 227 return vb + matchObj.group(2) + code 228 229 mainSvg = re.sub(r'xlink:href="([^"]*)"(.*>)', 230 ___insertSecondarySvgs, img2svg 231 ) 232 233 # remove buggy white-space (for notebooks) 234 mainSvg = mainSvg.replace("white-space:pre;", "") 235 return mainSvg 236 237 238def _reprGraph(gr, size, asString, format=None): 239 """ 240 repr a pydot graph in a notebook 241 242 :param string size : size of the rendered graph 243 :param boolean asString : display the graph or return a string containing the corresponding HTML fragment 244 """ 245 if size is not None: 246 gr.set_size(size) 247 248 if format is None: 249 format = gum.config["notebook", "graph_format"] 250 251 if format == "svg": 252 gsvg = IPython.display.SVG(__insertLinkedSVGs(gr.create_svg().decode('utf-8'))) 253 if asString: 254 return gsvg.data 255 else: 256 IPython.display.display(gsvg) 257 else: 258 i = IPython.core.display.Image(format="png", data=gr.create_png()) 259 if asString: 260 return f'<img style="margin:0" src="data:image/png;base64,{encodebytes(i.data).decode()}"/>' 261 else: 262 IPython.core.display.display_png(i) 263 264 265def showGraph(gr, size=None): 266 """ 267 show a pydot graph in a notebook 268 269 :param gr: pydot graph 270 :param size: size of the rendered graph 271 :return: the representation of the graph 272 """ 273 if size is None: 274 size = gum.config["notebook", "default_graph_size"] 275 276 return _reprGraph(gr, size, asString=False) 277 278 279def getGraph(gr, size=None): 280 """ 281 get a HTML string representation of pydot graph 282 283 :param gr: pydot graph 284 :param size: size of the rendered graph 285 :param format: render as "png" or "svg" 286 :return: the HTML representation of the graph as a string 287 """ 288 if size is None: 289 size = gum.config["notebook", "default_graph_size"] 290 return _reprGraph(gr, size, asString=True) 291 292 293def _from_dotstring(dotstring): 294 g = dot.graph_from_dot_data(dotstring) 295 g.set_bgcolor("transparent") 296 for e in g.get_edges(): 297 if e.get_color() is None: 298 e.set_color(getBlackInTheme()) 299 for n in g.get_nodes(): 300 if n.get_color() is None: 301 n.set_color(getBlackInTheme()) 302 if n.get_fontcolor() is None: 303 n.set_fontcolor(getBlackInTheme()) 304 return g 305 306 307def showDot(dotstring, size=None): 308 """ 309 show a dot string as a graph 310 311 :param dotstring: dot string 312 :param size: size of the rendered graph 313 :return: the representation of the graph 314 """ 315 if size is None: 316 size = gum.config["notebook", "default_graph_size"] 317 return showGraph(_from_dotstring(dotstring), size) 318 319 320def getDot(dotstring, size=None): 321 """ 322 get a dot string as a HTML string 323 324 :param dotstring: dot string 325 :param size: size of the rendered graph 326 :param format: render as "png" or "svg" 327 :param bg: color for background 328 :return: the HTML representation of the graph 329 """ 330 if size is None: 331 size = gum.config["notebook", "default_graph_size"] 332 333 return getGraph(_from_dotstring(dotstring), size) 334 335 336def getBNDiff(bn1, bn2, size=None): 337 """ get a HTML string representation of a graphical diff between the arcs of _bn1 (reference) with those of _bn2. 338 339 * full black line: the arc is common for both 340 * full red line: the arc is common but inverted in _bn2 341 * dotted black line: the arc is added in _bn2 342 * dotted red line: the arc is removed in _bn2 343 344 :param BayesNet bn1: referent model for the comparison 345 :param BayesNet bn2: bn compared to the referent model 346 :param size: size of the rendered graph 347 """ 348 if size is None: 349 size = gum.config["notebook", "default_graph_size"] 350 cmp = GraphicalBNComparator(bn1, bn2) 351 return getGraph(cmp.dotDiff(), size) 352 353 354def showBNDiff(bn1, bn2, size=None): 355 """ show a graphical diff between the arcs of _bn1 (reference) with those of _bn2. 356 357 * full black line: the arc is common for both 358 * full red line: the arc is common but inverted in _bn2 359 * dotted black line: the arc is added in _bn2 360 * dotted red line: the arc is removed in _bn2 361 362 :param BayesNet bn1: referent model for the comparison 363 :param BayesNet bn2: bn compared to the referent model 364 :param size: size of the rendered graph 365 """ 366 if size is None: 367 size = gum.config["notebook", "default_graph_size"] 368 cmp = GraphicalBNComparator(bn1, bn2) 369 showGraph(cmp.dotDiff(), size) 370 371 372def showInformation(*args, **kwargs): 373 print( 374 "[pyAgrum] pyAgrum.lib.notebook.showInformation is deprecated since 0.20.2. Please use pyAgrum.lib.explain.showInfomation instead." 375 ) 376 import pyAgrum.lib.explain as explain 377 explain.showInformation(*args, **kwargs) 378 379 380def showJunctionTree(bn, withNames=True, size=None): 381 """ 382 Show a junction tree 383 384 :param bn: the Bayesian network 385 :param boolean withNames: display the variable names or the node id in the clique 386 :param size: size of the rendered graph 387 :return: the representation of the graph 388 """ 389 if size is None: 390 size = gum.config["notebook", "default_graph_size"] 391 392 jtg = gum.JunctionTreeGenerator() 393 jt = jtg.junctionTree(bn) 394 395 jt._engine = jtg 396 jtg._model = bn 397 398 if withNames: 399 return showDot(jt.toDotWithNames(bn), size) 400 else: 401 return showDot(jt.toDot(), size) 402 403 404def getJunctionTree(bn, withNames=True, size=None): 405 """ 406 get a HTML string for a junction tree (more specifically a join tree) 407 408 :param bn: the Bayesian network 409 :param boolean withNames: display the variable names or the node id in the clique 410 :param size: size of the rendered graph 411 :return: the HTML representation of the graph 412 """ 413 if size is None: 414 size = gum.config["notebook", "junctiontree_graph_size"] 415 416 jtg = gum.JunctionTreeGenerator() 417 jt = jtg.junctionTree(bn) 418 419 jt._engine = jtg 420 jtg._model = bn 421 422 if withNames: 423 return getDot(jt.toDotWithNames(bn), size) 424 else: 425 return getDot(jt.toDot(), size) 426 427 428def showProba(p, scale=1.0): 429 """ 430 Show a mono-dim Potential 431 432 :param p: the mono-dim Potential 433 :param scale: the scale (zoom) 434 """ 435 fig = proba2histo(p, scale) 436 set_matplotlib_formats(gum.config["notebook", "graph_format"]) 437 plt.show() 438 439 440def _getMatplotFig(fig): 441 bio = io.BytesIO() # bytes buffer for the plot 442 fig.savefig(bio, format='png', bbox_inches='tight') # .canvas.print_png(bio) # make a png of the plot in the buffer 443 444 # encode the bytes as string using base 64 445 sB64Img = base64.b64encode(bio.getvalue()).decode() 446 res = f'<img src="data:image/png;base64,{sB64Img}\n">' 447 plt.close() 448 return res 449 450 451def getProba(p, scale=1.0): 452 """ 453 get a mono-dim Potential as html (png) img 454 455 :param p: the mono-dim Potential 456 :param scale: the scale (zoom) 457 """ 458 set_matplotlib_formats(gum.config["notebook", "graph_format"]) 459 return _getMatplotFig(proba2histo(p, scale)) 460 461 462def showProbaMinMax(pmin, pmax, scale=1.0): 463 """ 464 Show a bi-Potential (min,max) 465 466 :param pmin: the mono-dim Potential for min values 467 :param pmax: the mono-dim Potential for max values 468 :param scale: the scale (zoom) 469 """ 470 fig = probaMinMaxH(pmin, pmax, scale) 471 set_matplotlib_formats(gum.config["notebook", "graph_format"]) 472 plt.show() 473 474 475def getProbaMinMax(pmin, pmax, scale=1.0): 476 """ 477 get a bi-Potential (min,max) as html (png) img 478 479 :param pmin: the mono-dim Potential for min values 480 :param pmax: the mono-dim Potential for max values 481 :param scale: the scale (zoom) 482 """ 483 set_matplotlib_formats(gum.config["notebook", "graph_format"]) 484 return _getMatplotFig(probaMinMaxH(pmin, pmax, scale)) 485 486 487def getPosterior(bn, evs, target): 488 """ 489 shortcut for proba2histo(gum.getPosterior(bn,evs,target)) 490 491 :param bn: the BayesNet 492 :type bn: gum.BayesNet 493 :param evs: map of evidence 494 :type evs: dict(str->int) 495 :param target: name of target variable 496 :type target: str 497 :return: the matplotlib graph 498 """ 499 fig = proba2histo(gum.getPosterior(bn, evs=evs, target=target)) 500 plt.close() 501 return _getMatplotFig(fig) 502 503 504def showPosterior(bn, evs, target): 505 """ 506 shortcut for showProba(gum.getPosterior(bn,evs,target)) 507 508 :param bn: the BayesNet 509 :param evs: map of evidence 510 :param target: name of target variable 511 """ 512 showProba(gum.getPosterior(bn, evs=evs, target=target)) 513 514 515def animApproximationScheme(apsc, scale=np.log10): 516 """ 517 show an animated version of an approximation algorithm 518 519 :param apsc: the approximation algorithm 520 :param scale: a function to apply to the figure 521 """ 522 f = plt.gcf() 523 524 h = gum.PythonApproximationListener( 525 apsc._asIApproximationSchemeConfiguration() 526 ) 527 apsc.setVerbosity(True) 528 apsc.listener = h 529 530 def stopper(x): 531 IPython.display.clear_output(True) 532 plt.title( 533 f"{x} \n Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}" 534 ) 535 536 def progresser(x, y, z): 537 if len(apsc.history()) < 10: 538 plt.xlim(1, 10) 539 else: 540 plt.xlim(1, len(apsc.history())) 541 plt.plot(scale(apsc.history()), 'g') 542 IPython.display.clear_output(True) 543 IPython.display.display(f) 544 545 h.setWhenStop(stopper) 546 h.setWhenProgress(progresser) 547 548 549def showApproximationScheme(apsc, scale=np.log10): 550 if apsc.verbosity(): 551 if len(apsc.history()) < 10: 552 plt.xlim(1, 10) 553 else: 554 plt.xlim(1, len(apsc.history())) 555 plt.title( 556 f"Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}" 557 ) 558 plt.plot(scale(apsc.history()), 'g') 559 560 561def showMN(mn, view=None, size=None, nodeColor=None, factorColor=None, edgeWidth=None, edgeColor=None, cmap=None, 562 cmapEdge=None 563 ): 564 """ 565 show a Markov network 566 567 :param mn: the markov network 568 :param view: 'graph' | 'factorgraph’ | None (default) 569 :param size: size of the rendered graph 570 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 571 :param factorColor: a function returning a value (beeween 0 and 1) to be shown as a color of factor. (used when view='factorgraph') 572 :param edgeWidth: a edgeMap of values to be shown as width of edges (used when view='graph') 573 :param edgeColor: a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph') 574 :param cmap: color map to show the colors 575 :param cmapEdge: color map to show the edge color if distinction is needed 576 :return: the graph 577 """ 578 if view is None: 579 view = gum.config["notebook", "default_markovnetwork_view"] 580 581 if size is None: 582 size = gum.config["notebook", "default_graph_size"] 583 584 if cmapEdge is None: 585 cmapEdge = cmap 586 587 if view == "graph": 588 dottxt = MN2UGdot(mn, size, nodeColor, edgeWidth, 589 edgeColor, cmap, cmapEdge 590 ) 591 else: 592 dottxt = MN2FactorGraphdot(mn, size, nodeColor, factorColor, cmapNode=cmap) 593 594 return showGraph(dottxt, size) 595 596 597def showInfluenceDiagram(diag, size=None): 598 """ 599 show an influence diagram as a graph 600 601 :param diag: the influence diagram 602 :param size: size of the rendered graph 603 :return: the representation of the influence diagram 604 """ 605 if size is None: 606 size = gum.config["influenceDiagram", "default_id_size"] 607 608 return showGraph(ID2dot(diag), size) 609 610 611def getInfluenceDiagram(diag, size=None): 612 """ 613 get a HTML string for an influence diagram as a graph 614 615 :param diag: the influence diagram 616 :param size: size of the rendered graph 617 :return: the HTML representation of the influence diagram 618 """ 619 if size is None: 620 size = gum.config["influenceDiagram", "default_id_size"] 621 622 return getGraph(ID2dot(diag), size) 623 624 625def showBN(bn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None): 626 """ 627 show a Bayesian network 628 629 :param bn: the Bayesian network 630 :param size: size of the rendered graph 631 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 632 :param arcWidth: a arcMap of values to be shown as width of arcs 633 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 634 :param cmap: color map to show the colors 635 :param cmapArc: color map to show the arc color if distinction is needed 636 :return: the graph 637 """ 638 if size is None: 639 size = gum.config["notebook", "default_graph_size"] 640 641 if cmapArc is None: 642 cmapArc = cmap 643 644 return showGraph(BN2dot(bn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc), size) 645 646 647def showCN(cn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None): 648 """ 649 show a credal network 650 651 :param cn: the credal network 652 :param size: size of the rendered graph 653 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 654 :param arcWidth: a arcMap of values to be shown as width of arcs 655 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 656 :param cmap: color map to show the colors 657 :param cmapArc: color map to show the arc color if distinction is needed 658 :return: the graph 659 """ 660 if size is None: 661 size = gum.config["notebook", "default_graph_size"] 662 663 if cmapArc is None: 664 cmapArc = cmap 665 666 return showGraph(CN2dot(cn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc), size) 667 668 669def getMN(mn, view=None, size=None, nodeColor=None, factorColor=None, edgeWidth=None, edgeColor=None, cmap=None, 670 cmapEdge=None 671 ): 672 """ 673 get an HTML string for a Markov network 674 675 :param mn: the markov network 676 :param view: 'graph' | 'factorgraph’ | None (default) 677 :param size: size of the rendered graph 678 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 679 :param factorColor: a function returning a value (beeween 0 and 1) to be shown as a color of factor. (used when view='factorgraph') 680 :param edgeWidth: a edgeMap of values to be shown as width of edges (used when view='graph') 681 :param edgeColor: a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph') 682 :param cmap: color map to show the colors 683 :param cmapEdge: color map to show the edge color if distinction is needed 684 :return: the graph 685 """ 686 if size is None: 687 size = gum.config["notebook", "default_graph_size"] 688 689 if cmapEdge is None: 690 cmapEdge = cmap 691 692 if view is None: 693 view = gum.config["notebook", "default_markovnetwork_view"] 694 695 if view == "graph": 696 dottxt = MN2UGdot(mn, size, nodeColor, edgeWidth, 697 edgeColor, cmap, cmapEdge 698 ) 699 else: 700 dottxt = MN2FactorGraphdot(mn, size, nodeColor, factorColor, cmapNode=cmap) 701 702 return getGraph(dottxt, size) 703 704 705def getBN(bn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None): 706 """ 707 get a HTML string for a Bayesian network 708 709 :param bn: the Bayesian network 710 :param size: size of the rendered graph 711 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 712 :param arcWidth: a arcMap of values to be shown as width of arcs 713 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 714 :param cmap: color map to show the colors 715 :param cmapArc: color map to show the arc color if distinction is needed 716 717 :return: the graph 718 """ 719 if size is None: 720 size = gum.config["notebook", "default_graph_size"] 721 722 if cmapArc is None: 723 cmapArc = cmap 724 725 return getGraph(BN2dot(bn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc), size) 726 727 728def getCN(cn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None): 729 """ 730 get a HTML string for a credal network 731 732 :param cn: the credal network 733 :param size: size of the rendered graph 734 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 735 :param arcWidth: a arcMap of values to be shown as width of arcs 736 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 737 :param cmap: color map to show the colors 738 :param cmapArc: color map to show the arc color if distinction is needed 739 740 :return: the graph 741 """ 742 if size is None: 743 size = gum.config["notebook", "default_graph_size"] 744 745 if cmapArc is None: 746 cmapArc = cmap 747 748 return getGraph(CN2dot(cn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc), size) 749 750 751def _get_showInference(model, engine=None, evs=None, targets=None, size=None, 752 nodeColor=None, factorColor=None, arcWidth=None, 753 arcColor=None, cmap=None, cmapArc=None, graph=None, view=None 754 ): 755 if size is None: 756 size = gum.config["notebook", "default_graph_inference_size"] 757 758 if evs is None: 759 evs = {} 760 761 if targets is None: 762 targets = {} 763 764 if isinstance(model, gum.BayesNet): 765 if engine is None: 766 engine = gum.LazyPropagation(model) 767 return BNinference2dot(model, size=size, engine=engine, evs=evs, targets=targets, nodeColor=nodeColor, 768 arcWidth=arcWidth, 769 arcColor=arcColor, 770 cmapNode=cmap, cmapArc=cmapArc 771 ) 772 elif isinstance(model, gum.MarkovNet): 773 if view is None: 774 view = gum.config["notebook", "default_markovnetwork_view"] 775 if engine is None: 776 engine = gum.ShaferShenoyMNInference(model) 777 778 if view == "graph": 779 return MNinference2UGdot(model, size=size, engine=engine, evs=evs, targets=targets, nodeColor=nodeColor, 780 factorColor=factorColor, 781 arcWidth=arcWidth, arcColor=arcColor, cmapNode=cmap, cmapArc=cmapArc 782 ) 783 else: 784 return MNinference2FactorGraphdot(model, size=size, engine=engine, evs=evs, targets=targets, 785 nodeColor=nodeColor, 786 factorColor=factorColor, cmapNode=cmap 787 ) 788 elif isinstance(model, gum.InfluenceDiagram): 789 if engine is None: 790 engine = gum.ShaferShenoyLIMIDInference(model) 791 return LIMIDinference2dot(model, size=size, engine=engine, evs=evs, targets=targets) 792 elif isinstance(model, gum.CredalNet): 793 if engine is None: 794 engine = gum.CNMonteCarloSampling(model) 795 return CNinference2dot(model, size=size, engine=engine, evs=evs, targets=targets, nodeColor=nodeColor, 796 arcWidth=arcWidth, arcColor=arcColor, cmapNode=cmap 797 ) 798 else: 799 raise gum.InvalidArgument( 800 "Argument model should be a PGM (BayesNet, MarkovNet or Influence Diagram" 801 ) 802 803 804def showInference(model, **kwargs): 805 """ 806 show pydot graph for an inference in a notebook 807 808 :param GraphicalModel model: the model in which to infer (pyAgrum.BayesNet, pyAgrum.MarkovNet or pyAgrum.InfluenceDiagram) 809 :param gum.Inference engine: inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, gum.ShaferShenoy for gum.MarkovNet and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram. 810 :param dictionnary evs: map of evidence 811 :param set targets: set of targets 812 :param string size: size of the rendered graph 813 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 814 :param factorColor: a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovNet representation) 815 :param arcWidth: a arcMap of values to be shown as width of arcs 816 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 817 :param cmap: color map to show the color of nodes and arcs 818 :param cmapArc: color map to show the vals of Arcs. 819 :param graph: only shows nodes that have their id in the graph (and not in the whole BN) 820 :param view: graph | factorgraph | None (default) for Markov network 821 :return: the desired representation of the inference 822 """ 823 if "size" in kwargs: 824 size = kwargs['size'] 825 else: 826 size = gum.config["notebook", "default_graph_inference_size"] 827 828 showGraph(_get_showInference(model, **kwargs), size) 829 830 831def getInference(model, **kwargs): 832 """ 833 get a HTML string for an inference in a notebook 834 835 :param GraphicalModel model: the model in which to infer (pyAgrum.BayesNet, pyAgrum.MarkovNet or 836 pyAgrum.InfluenceDiagram) 837 :param gum.Inference engine: inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, 838 gum.ShaferShenoy for gum.MarkovNet and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram. 839 :param dictionnary evs: map of evidence 840 :param set targets: set of targets 841 :param string size: size of the rendered graph 842 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 843 :param factorColor: a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovNet representation) 844 :param arcWidth: a arcMap of values to be shown as width of arcs 845 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 846 :param cmap: color map to show the color of nodes and arcs 847 :param cmapArc: color map to show the vals of Arcs. 848 :param graph: only shows nodes that have their id in the graph (and not in the whole BN) 849 :param view: graph | factorgraph | None (default) for Markov network 850 851 :return: the desired representation of the inference 852 """ 853 if "size" in kwargs: 854 size = kwargs['size'] 855 else: 856 size = gum.config["notebook", "default_graph_inference_size"] 857 858 grinf = _get_showInference(model, **kwargs) 859 return getGraph(grinf, size) 860 861 862def _reprPotential(pot, digits=None, withColors=True, varnames=None, asString=False): 863 """ 864 return a representation of a gum.Potential as a HTML table. 865 The first dimension is special (horizontal) due to the representation of conditional probability table 866 867 :param pot: the potential to get 868 :param digits: number of digits to show 869 :param withColors: bgcolor for proba cells or not 870 :param varnames: the aliases for variables name in the table 871 :param asString: display the table or a HTML string 872 873 :return: the representation 874 """ 875 from IPython.core.display import HTML 876 877 878 r0, g0, b0 = gumcols.hex2rgb(gum.config['notebook', 'potential_color_0']) 879 r1, g1, b1 = gumcols.hex2rgb(gum.config['notebook', 'potential_color_1']) 880 881 def _rgb(r, g, b): 882 return '#%02x%02x%02x' % (r, g, b) 883 884 def _mkCell(val): 885 s = "<td style='" 886 if withColors and (0 <= val <= 1): 887 r = int(r0 + val * (r1 - r0)) 888 g = int(g0 + val * (g1 - g0)) 889 b = int(b0 + val * (b1 - b0)) 890 891 tx = gumcols.rgb2brightness(r,g,b) 892 893 s += "color:" + tx + ";background-color:" + _rgb(r, g, b) + ";" 894 s += "text-align:right;'>{:." + str(digits) + "f}</td>" 895 return s.format(val) 896 897 if digits is None: 898 digits = gum.config['notebook', 'potential_visible_digits'] 899 900 if gum.config["notebook", "potential_with_colors"] == "False": 901 withColors = False 902 903 html = list() 904 html.append('<table style="border:1px solid black;">') 905 if pot.empty(): 906 html.append( 907 "<tr><th> </th></tr>" 908 ) 909 html.append("<tr>" + _mkCell(pot.get(gum.Instantiation())) + "</tr>") 910 else: 911 if varnames is not None and len(varnames) != pot.nbrDim(): 912 raise ValueError( 913 f"varnames contains {len(varnames)} value(s) instead of the needed {pot.nbrDim()} value(s)." 914 ) 915 916 nparents = pot.nbrDim() - 1 917 var = pot.variable(0) 918 varname = var.name() if varnames == None else varnames[0] 919 920 # first line 921 if nparents > 0: 922 html.append(f"""<tr><th colspan='{nparents}'></th> 923 <th colspan='{var.domainSize()}' style='border:1px solid black;color:black;background-color:#808080;'><center>{varname}</center> 924 </th></tr>""" 925 ) 926 else: 927 html.append(f"""<tr style='border:1px solid black;color:black;background-color:#808080'> 928 <th colspan='{var.domainSize()}'><center>{varname}</center></th></tr>""" 929 ) 930 931 # second line 932 s = "<tr>" 933 if nparents > 0: 934 # parents order 935 if gum.config["notebook", "potential_parent_values"] == "revmerge": 936 pmin, pmax, pinc = nparents - 1, 0 - 1, -1 937 else: 938 pmin, pmax, pinc = 0, nparents, 1 939 for par in range(pmin, pmax, pinc): 940 parent = pot.var_names[par] if varnames is None else varnames[par] 941 s += f"<th style='border:1px solid black;color:black;background-color:#808080'><center>{parent}</center></th>" 942 943 for label in var.labels(): 944 s += f"""<th style='border:1px solid black;border-bottom-style: double;color:black;background-color:#BBBBBB'> 945 <center>{label}</center></th>""" 946 s += '</tr>' 947 948 html.append(s) 949 950 inst = gum.Instantiation(pot) 951 off = 1 952 offset = dict() 953 for i in range(1, nparents + 1): 954 offset[i] = off 955 off *= inst.variable(i).domainSize() 956 957 inst.setFirst() 958 while not inst.end(): 959 s = "<tr>" 960 # parents order 961 if gum.config["notebook", "potential_parent_values"] == "revmerge": 962 pmin, pmax, pinc = 1, nparents + 1, 1 963 else: 964 pmin, pmax, pinc = nparents, 0, -1 965 for par in range(pmin, pmax, pinc): 966 label = inst.variable(par).label(inst.val(par)) 967 if par == 1 or gum.config["notebook", "potential_parent_values"] == "nomerge": 968 s += f"<th style='border:1px solid black;color:black;background-color:#BBBBBB'><center>{label}</center></th>" 969 else: 970 if sum([inst.val(i) for i in range(1, par)]) == 0: 971 s += f"""<th style='border:1px solid black;color:black;background-color:#BBBBBB;' rowspan = '{offset[par]}'> 972 <center>{label}</center></th>""" 973 for j in range(pot.variable(0).domainSize()): 974 s += _mkCell(pot.get(inst)) 975 inst.inc() 976 s += "</tr>" 977 html.append(s) 978 979 html.append("</table>") 980 981 if asString: 982 return "\n".join(html) 983 else: 984 return IPython.display.HTML("".join(html)) 985 986 987def __isKindOfProba(pot): 988 """ 989 check if pot is a joint proba or a CPT 990 :param pot: the potential 991 :return: True or False 992 """ 993 epsilon = 1e-5 994 if pot.min() < -epsilon: 995 return False 996 if pot.max() > 1 + epsilon: 997 return False 998 999 # is it a joint proba ? 1000 if abs(pot.sum() - 1) < epsilon: 1001 return True 1002 1003 # marginal and then not proba (because of the test just above) 1004 if pot.nbrDim() < 2: 1005 return False 1006 1007 # is is a CPT ? 1008 q = pot.margSumOut([pot.variable(0).name()]) 1009 if abs(q.max() - 1) > epsilon: 1010 return False 1011 if abs(q.min() - 1) > epsilon: 1012 return False 1013 return True 1014 1015 1016def showPotential(pot, digits=None, withColors=None, varnames=None): 1017 """ 1018 show a gum.Potential as a HTML table. 1019 The first dimension is special (horizontal) due to the representation of conditional probability table 1020 1021 :param gum.Potential pot: the potential to get 1022 :param int digits: number of digits to show 1023 :param: boolean withColors : bgcolor for proba cells or not 1024 :param list of strings varnames: the aliases for variables name in the table 1025 :return: the display of the potential 1026 """ 1027 if withColors is None: 1028 if gum.config["notebook", "potential_with_colors"] == "False": 1029 withColors = False 1030 else: 1031 withColors = __isKindOfProba(pot) 1032 1033 s = _reprPotential(pot, digits, withColors, varnames, asString=False) 1034 IPython.display.display(s) 1035 1036 1037def getPotential(pot, digits=None, withColors=None, varnames=None): 1038 """ 1039 return a HTML string of a gum.Potential as a HTML table. 1040 The first dimension is special (horizontal) due to the representation of conditional probability table 1041 1042 :param gum.Potential pot: the potential to get 1043 :param int digits: number of digits to show 1044 :param: boolean withColors : bgcolor for proba cells or not 1045 :param list of strings varnames: the aliases for variables name in the table 1046 :return: the HTML string 1047 """ 1048 if withColors is None: 1049 if gum.config["notebook", "potential_with_colors"] == "False": 1050 withColors = False 1051 else: 1052 withColors = __isKindOfProba(pot) 1053 1054 return _reprPotential(pot, digits, withColors, varnames, asString=True) 1055 1056 1057def getSideBySide(*args, **kwargs): 1058 """ 1059 create an HTML table for args as string (using string, _repr_html_() or str()) 1060 1061 :param args: HMTL fragments as string arg, arg._repr_html_() or str(arg) 1062 :param captions: list of strings (captions) 1063 :return: a string representing the table 1064 """ 1065 vals = {'captions', 'valign'} 1066 if not set(kwargs.keys()).issubset(vals): 1067 raise TypeError(f"sideBySide() got unexpected keyword argument(s) : '{set(kwargs.keys()).difference(vals)}'") 1068 1069 if 'captions' in kwargs: 1070 captions = kwargs['captions'] 1071 else: 1072 captions = None 1073 1074 if 'valign' in kwargs: 1075 v_align = 'vertical-align:' + kwargs['valign'] + ';' 1076 else: 1077 v_align = "" 1078 1079 s = '<table style="border-style: hidden; border-collapse: collapse;" width="100%">' 1080 1081 def reprHTML(s): 1082 if isinstance(s, str): 1083 return s 1084 elif hasattr(s, '_repr_html_'): 1085 return s._repr_html_() 1086 else: 1087 return str(s) 1088 1089 s += '<tr><td style="border-top:hidden;border-bottom:hidden;' + v_align + '"><div align="center" style="' + v_align \ 1090 + '">' 1091 s += ( 1092 '</div></td><td style="border-top:hidden;border-bottom:hidden;' + v_align + '"><div align="center" style="' + 1093 v_align + '">').join( 1094 [reprHTML(arg) 1095 for arg in args] 1096 ) 1097 s += '</div></td></tr>' 1098 1099 if captions is not None: 1100 s += '<tr><td style="border-top:hidden;border-bottom:hidden;"><div align="center"><small>' 1101 s += '</small></div></td><td style="border-top:hidden;border-bottom:hidden;"><div align="center"><small>'.join( 1102 captions 1103 ) 1104 s += '</small></div></td></tr>' 1105 1106 s += '</table>' 1107 return s 1108 1109 1110def sideBySide(*args, **kwargs): 1111 """ 1112 display side by side args as HMTL fragment (using string, _repr_html_() or str()) 1113 1114 :param args: HMTL fragments as string arg, arg._repr_html_() or str(arg) 1115 :param captions: list of strings (captions) 1116 """ 1117 IPython.display.display(IPython.display.HTML(getSideBySide(*args, **kwargs))) 1118 1119 1120def getInferenceEngine(ie, inferenceCaption): 1121 """ 1122 display an inference as a BN+ lists of hard/soft evidence and list of targets 1123 1124 :param gum.InferenceEngine ie: inference engine 1125 :param string inferenceCaption: title for caption 1126 1127 """ 1128 t = '<div align="left"><ul>' 1129 if ie.nbrHardEvidence() > 0: 1130 t += "<li><b>hard evidence</b><br/>" 1131 t += ", ".join([ie.BN().variable(n).name() 1132 for n in ie.hardEvidenceNodes()] 1133 ) 1134 t += "</li>" 1135 if ie.nbrSoftEvidence() > 0: 1136 t += "<li><b>soft evidence</b><br/>" 1137 t += ", ".join([ie.BN().variable(n).name() 1138 for n in ie.softEvidenceNodes()] 1139 ) 1140 t += "</li>" 1141 if ie.nbrTargets() > 0: 1142 t += "<li><b>target(s)</b><br/>" 1143 if ie.nbrTargets() == ie.BN().size(): 1144 t += " all" 1145 else: 1146 t += ", ".join([ie.BN().variable(n).name() for n in ie.targets()]) 1147 t += "</li>" 1148 1149 if hasattr(ie, 'nbrJointTargets'): 1150 if ie.nbrJointTargets() > 0: 1151 t += "<li><b>Joint target(s)</b><br/>" 1152 t += ", ".join(['[' 1153 + (", ".join([ie.BN().variable(n).name() 1154 for n in ns] 1155 )) 1156 + ']' for ns in ie.jointTargets()] 1157 ) 1158 t += "</li>" 1159 t += '</ul></div>' 1160 return getSideBySide(getBN(ie.BN()), t, captions=[inferenceCaption, "Evidence and targets"]) 1161 1162 1163def getJT(jt, size=None): 1164 if gum.config["notebook", "junctiontree_with_names"] == "True": 1165 def cliqlabels(c): 1166 return " ".join( 1167 sorted([model.variable(n).name() for n in jt.clique(c)]) 1168 ) 1169 1170 def cliqnames( 1171 c 1172 ): 1173 return "-".join(sorted([model.variable(n).name() for n in jt.clique(c)])) 1174 1175 def seplabels(c1, c2): 1176 return " ".join( 1177 sorted([model.variable(n).name() for n in jt.separator(c1, c2)]) 1178 ) 1179 1180 def sepnames(c1, c2): 1181 return cliqnames(c1) + '+' + cliqnames(c2) 1182 else: 1183 def cliqlabels(c): 1184 return " ".join([str(n) for n in sorted(jt.clique(c))]) 1185 1186 def cliqnames(c): 1187 return "-".join([str(n) for n in sorted(jt.clique(c))]) 1188 1189 def seplabels(c1, c2): 1190 return " ".join( 1191 [str(n) for n in sorted(jt.separator(c1, c2))] 1192 ) 1193 1194 def sepnames(c1, c2): 1195 return cliqnames(c1) + '^' + cliqnames(c2) 1196 1197 model = jt._engine._model 1198 name = model.propertyWithDefault( 1199 "name", str(type(model)).split(".")[-1][:-2] 1200 ) 1201 graph = dot.Dot(graph_type='graph', graph_name=name, bgcolor="transparent") 1202 for c in jt.nodes(): 1203 graph.add_node(dot.Node('"' + cliqnames(c) + '"', 1204 label='"' + cliqlabels(c) + '"', 1205 style="filled", 1206 fillcolor=gum.config["notebook", 1207 "junctiontree_clique_bgcolor"], 1208 fontcolor=gum.config["notebook", 1209 "junctiontree_clique_fgcolor"], 1210 fontsize=gum.config["notebook", "junctiontree_clique_fontsize"] 1211 ) 1212 ) 1213 for c1, c2 in jt.edges(): 1214 graph.add_node(dot.Node('"' + sepnames(c1, c2) + '"', 1215 label='"' + seplabels(c1, c2) + '"', 1216 style="filled", 1217 shape="box", width="0", height="0", margin="0.02", 1218 fillcolor=gum.config["notebook", 1219 "junctiontree_separator_bgcolor"], 1220 fontcolor=gum.config["notebook", 1221 "junctiontree_separator_fgcolor"], 1222 fontsize=gum.config["notebook", "junctiontree_separator_fontsize"] 1223 ) 1224 ) 1225 for c1, c2 in jt.edges(): 1226 graph.add_edge(dot.Edge('"' + cliqnames(c1) + 1227 '"', '"' + sepnames(c1, c2) + '"' 1228 ) 1229 ) 1230 graph.add_edge(dot.Edge('"' + sepnames(c1, c2) + 1231 '"', '"' + cliqnames(c2) + '"' 1232 ) 1233 ) 1234 1235 if size is None: 1236 size = gum.config["notebook", "default_graph_size"] 1237 graph.set_size(gum.config["notebook", "junctiontree_graph_size"]) 1238 1239 return graph.to_string() 1240 1241 1242def getCliqueGraph(cg, size=None): 1243 """get a representation for clique graph. Special case for junction tree 1244 (clique graph with an attribute `_engine`) 1245 1246 Args: 1247 cg (gum.CliqueGraph): the clique graph (maybe junction tree for a _model) to 1248 represent 1249 """ 1250 if hasattr(cg, "_engine"): 1251 return getDot(getJT(cg), size) 1252 else: 1253 return getDot(cg.toDot()) 1254 1255 1256def show(model, **kwargs): 1257 """ 1258 propose a (visual) representation of a model in a notebook 1259 1260 :param GraphicalModel model: the model to show (pyAgrum.BayesNet, pyAgrum.MarkovNet, pyAgrum.InfluenceDiagram or pyAgrum.Potential) 1261 1262 :param int size: optional size for the graphical model (no effect for Potential) 1263 """ 1264 if isinstance(model, gum.BayesNet): 1265 showBN(model, **kwargs) 1266 elif isinstance(model, gum.MarkovNet): 1267 showMN(model, **kwargs) 1268 elif isinstance(model, gum.InfluenceDiagram): 1269 showInfluenceDiagram(model, **kwargs) 1270 elif isinstance(model, gum.CredalNet): 1271 showCN(model, **kwargs) 1272 elif isinstance(model, gum.Potential): 1273 showPotential(model) 1274 elif hasattr(model, "toDot"): 1275 showDot(dot.graph_from_dot_data(model.toDot()), **kwargs) 1276 else: 1277 raise gum.InvalidArgument( 1278 "Argument model should be a PGM (BayesNet, MarkovNet, Influence Diagram or Potential or ..." 1279 ) 1280 1281 1282def export(model, filename, **kwargs): 1283 """ 1284 export the graphical representation of the model in filename (png, pdf,etc.) 1285 1286 :param GraphicalModel model: the model to show (pyAgrum.BayesNet, pyAgrum.MarkovNet, pyAgrum.InfluenceDiagram or pyAgrum.Potential) 1287 :param str filename: the name of the resulting file (suffix in ['pdf', 'png', 'fig', 'jpg', 'svg', 'ps']) 1288 """ 1289 format = filename.split(".")[-1] 1290 if format not in ['pdf', 'png', 'fig', 'jpg', 'svg', 'ps']: 1291 raise Exception( 1292 f"{filename} in not a correct filename for export : extension '{format}' not in [pdf,png,fig,jpg,svg]." 1293 ) 1294 1295 if isinstance(model, gum.BayesNet): 1296 fig = BN2dot(model, **kwargs) 1297 elif isinstance(model, gum.MarkovNet): 1298 if gum.config["notebook", "default_markovnetwork_view"] == "graph": 1299 fig = MN2UGdot(model, **kwargs) 1300 else: 1301 fig = MN2FactorGraphdot(model, **kwargs) 1302 elif isinstance(model, gum.InfluenceDiagram): 1303 fig = ID2dot(model, **kwargs) 1304 elif isinstance(model, gum.CredalNet): 1305 fig = CN2dot(model, **kwargs) 1306 elif hasattr(model, "toDot"): 1307 fig = dot.graph_from_dot_data(model.toDot(), **kwargs) 1308 else: 1309 raise gum.InvalidArgument( 1310 "Argument model should be a PGM (BayesNet, MarkovNet or Influence Diagram" 1311 ) 1312 fig.write(filename, format=format) 1313 1314 1315def exportInference(model, filename, **kwargs): 1316 """ 1317 the graphical representation of an inference in a notebook 1318 1319 :param GraphicalModel model: the model in which to infer (pyAgrum.BayesNet, pyAgrum.MarkovNet or 1320 pyAgrum.InfluenceDiagram) 1321 :param str filename: the name of the resulting file (suffix in ['pdf', 'png', 'ps']) 1322 :param gum.Inference engine: inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, 1323 gum.ShaferShenoy for gum.MarkovNet and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram. 1324 :param dictionnary evs: map of evidence 1325 :param set targets: set of targets 1326 :param string size: size of the rendered graph 1327 :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1) 1328 :param factorColor: a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovNet representation) 1329 :param arcWidth: a arcMap of values to be shown as width of arcs 1330 :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs 1331 :param cmap: color map to show the color of nodes and arcs 1332 :param cmapArc: color map to show the vals of Arcs. 1333 :param graph: only shows nodes that have their id in the graph (and not in the whole BN) 1334 :param view: graph | factorgraph | None (default) for Markov network 1335 :return: the desired representation of the inference 1336 """ 1337 format = filename.split(".")[-1] 1338 if format not in ['pdf', 'png', 'ps']: 1339 raise Exception( 1340 f"{filename} in not a correct filename for export : extension '{format}' not in [pdf,png,ps]." 1341 ) 1342 1343 import cairosvg 1344 1345 if "size" in kwargs: 1346 size = kwargs['size'] 1347 else: 1348 size = gum.config["notebook", "default_graph_inference_size"] 1349 1350 svgtxt = _reprGraph( 1351 _get_showInference(model, **kwargs), size=size, asString=True, format="svg" 1352 ) 1353 1354 if format == "pdf": 1355 cairosvg.svg2pdf(bytestring=svgtxt, write_to=filename) 1356 elif format == "png": 1357 cairosvg.svg2png(bytestring=svgtxt, write_to=filename) 1358 else: # format=="ps" 1359 cairosvg.svg2ps(bytestring=svgtxt, write_to=filename) 1360 1361 1362def _update_config(): 1363 # hook to control some parameters for notebook when config changes 1364 mpl.rcParams['figure.facecolor'] = gum.config["notebook", "figure_facecolor"] 1365 set_matplotlib_formats(gum.config["notebook", "graph_format"]) 1366 1367 1368# check if an instance of ipython exists 1369try: 1370 get_ipython 1371except NameError as e: 1372 import warnings 1373 1374 warnings.warn(""" 1375 ** pyAgrum.lib.notebook has to be imported from an IPython's instance (mainly notebook). 1376 """ 1377 ) 1378else: 1379 gum.config.add_hook(_update_config) 1380 gum.config.run_hooks() 1381 1382 # adding _repr_html_ to some pyAgrum classes ! 1383 gum.BayesNet._repr_html_ = lambda self: getBN(self) 1384 gum.BayesNetFragment._repr_html_ = lambda self: getBN(self) 1385 gum.MarkovNet._repr_html_ = lambda self: getMN(self) 1386 gum.BayesNetFragment._repr_html_ = lambda self: getBN(self) 1387 gum.InfluenceDiagram._repr_html_ = lambda self: getInfluenceDiagram(self) 1388 gum.CredalNet._repr_html_ = lambda self: getCN(self) 1389 1390 gum.CliqueGraph._repr_html_ = lambda self: getCliqueGraph(self) 1391 1392 gum.Potential._repr_html_ = lambda self: getPotential(self) 1393 gum.LazyPropagation._repr_html_ = lambda self: getInferenceEngine( 1394 self, "Lazy Propagation on this BN" 1395 ) 1396 1397 gum.UndiGraph._repr_html_ = lambda self: getDot(self.toDot()) 1398 gum.DiGraph._repr_html_ = lambda self: getDot(self.toDot()) 1399 gum.MixedGraph._repr_html_ = lambda self: getDot(self.toDot()) 1400 gum.DAG._repr_html_ = lambda self: getDot(self.toDot()) 1401 gum.EssentialGraph._repr_html_ = lambda self: getDot(self.toDot()) 1402 gum.MarkovBlanket._repr_html_ = lambda self: getDot(self.toDot()) 1403