1# (c) Copyright by Pierre-Henri Wuillemin, UPMC, 2017
2# (pierre-henri.wuillemin@lip6.fr)
3# Permission to use, copy, modify, and distribute this
4# software and its documentation for any purpose and
5# without fee or royalty is hereby granted, provided
6# that the above copyright notice appear in all copies
7# and that both that copyright notice and this permission
8# notice appear in supporting documentation or portions
9# thereof, including modifications, that you make.
10# THE AUTHOR P.H. WUILLEMIN  DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED
12# WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT
13# SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT
14# OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
15# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
16# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
17# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
18# OR PERFORMANCE OF THIS SOFTWARE!
19
20"""
21tools for BN in jupyter notebook
22"""
23
24import time
25import re
26import sys
27
28# fix DeprecationWarning of base64.encodestring()
29try:
30  from base64 import encodebytes
31except ImportError:  # 3+
32  from base64 import encodestring as encodebytes
33
34import io
35import base64
36
37import matplotlib as mpl
38import matplotlib.pyplot as plt
39
try:
  from matplotlib_inline.backend_inline import set_matplotlib_formats as set_matplotlib_formats
except ImportError:  # because of python 2.7, matplotlib-inline cannot be part of requirements.txt
  def set_matplotlib_formats(*args, **kwargs):
    """
    Fallback used when the matplotlib_inline package cannot be imported.

    Accepts and ignores any argument; under Python 3 it only prints a hint
    suggesting to install matplotlib-inline.
    """
    running_python3 = sys.version.startswith("3")
    if running_python3:
      print("** pyAgrum** For better visualizations, please install matplotlib-inline.")
48
49import numpy as np
50import pydotplus as dot
51
52import IPython.core.display
53import IPython.core.pylabtools
54import IPython.display
55
56import pyAgrum as gum
57from pyAgrum.lib.bn2graph import BN2dot, BNinference2dot
58from pyAgrum.lib.cn2graph import CN2dot, CNinference2dot
59from pyAgrum.lib.id2graph import ID2dot, LIMIDinference2dot
60from pyAgrum.lib.mn2graph import MN2UGdot, MNinference2UGdot
61from pyAgrum.lib.mn2graph import MN2FactorGraphdot, MNinference2FactorGraphdot
62from pyAgrum.lib.bn_vs_bn import GraphicalBNComparator
63from pyAgrum.lib.proba_histogram import proba2histo, probaMinMaxH
64
65from pyAgrum.lib._colors import setDarkTheme,setLightTheme,getBlackInTheme
66from pyAgrum.lib._colors import forDarkTheme,forLightTheme # obsolete since 0.21.0
67
68import pyAgrum.lib._colors as gumcols
69
70
class FlowLayout(object):
  """
  A class / object to display plots in a horizontal / flow layout below a cell

  HTML fragments and PNG-encoded plots are accumulated in ``self.sHtml`` and
  rendered together by :meth:`display`.

  based on : https://stackoverflow.com/questions/21754976/ipython-notebook-arrange-plots-horizontally
  """

  def __init__(self):
    self.clear()

  def clear(self):
    """
    clear the flow (the buffer is reset to the CSS preamble only)
    """
    # string buffer for the HTML: initially some CSS; images to be appended
    self.sHtml = """
      <style>
      .floating-box {
      display: inline-block;
      margin: 7px;
      padding : 3px;
      border: 2px solid #FFFFFF;
      valign:middle;
      background-color: #FDFDFD;
      }
      </style>
      """

  def _getTitle(self, title):
    """
    build the HTML caption placed under an element

    :param title: the caption (``None`` or ``""`` means no caption)
    :return: the HTML fragment for the caption, or an empty string
    """
    if title is None or title == "":
      return ""
    return f"<br><center><small><em>{title}</em></small></center>"

  def add_html(self, html, title=""):
    """
    add an html element in the row

    :param html: the HTML fragment to insert in a floating box
    :param title: optional caption displayed under the element
    """
    self.sHtml += f'<div class="floating-box">{html}{self._getTitle(title)}</div>'

  def add_separator(self, size=3):
    """
    add a (poor) separation between elements in a row

    :param size: number of non-breaking spaces used as separator
    """
    self.add_html("&nbsp;" * size)

  def add_plot(self, oAxes, title=""):
    """
    Add a PNG representation of a Matplotlib Axes object

    :param oAxes: an object providing ``get_figure()``; the returned figure is rendered as PNG
    :param title: optional caption displayed under the image
    """
    Bio = io.BytesIO()  # bytes buffer for the plot
    fig = oAxes.get_figure()
    fig.canvas.print_png(Bio)  # make a png of the plot in the buffer

    # encode the bytes as string using base 64
    sB64Img = base64.b64encode(Bio.getvalue()).decode()
    self.sHtml += f'<div class="floating-box"><img src="data:image/png;base64,{sB64Img}\n">{self._getTitle(title)}</div>'
    # close the figure that has just been rendered (not whichever figure happens to be current)
    plt.close(fig)

  def new_line(self):
    """
    add a breakline (a new row)
    """
    self.sHtml += '<br/>'

  def html(self):
    """
    Returns its content as HTML object
    """
    return IPython.display.HTML(self.sHtml)

  def display(self):
    """
    Display the accumulated HTML, then clear the flow
    """
    IPython.display.display(self.html())
    self.clear()

  def row(self, *args, captions=None):
    """
    Display all the arguments in a single row.

    Any content previously accumulated in the flow is discarded first.
    Objects with ``get_figure`` are added as plots, objects with ``_repr_html_``
    are added through that representation, anything else is inserted as is.

    :param args: the elements to show
    :param captions: optional sequence of captions, matched by position; elements
                     without a matching caption are shown without title
    """
    self.clear()
    for i, arg in enumerate(args):
      # a caption list shorter than args must not break the display
      if captions is not None and i < len(captions):
        t = captions[i]
      else:
        t = ""

      if hasattr(arg, "get_figure"):
        self.add_plot(arg, title=t)
      elif hasattr(arg, "_repr_html_"):
        self.add_html(arg._repr_html_(), title=t)
      else:
        self.add_html(arg, title=t)

    self.display()
164
165
# module-level FlowLayout instance, available as `flow` to users of this module
flow = FlowLayout()
167
168
def configuration():
  """
  Display an HTML table listing the platform and the versions of Python,
  IPython, matplotlib, numpy and pyAgrum, followed by the current local time.
  """
  import os
  import sys
  from collections import OrderedDict

  versions = OrderedDict(
    [
      ("OS", f"{os.name} [{sys.platform}]"),
      ("Python", sys.version),
      ("IPython", IPython.__version__),
      ("MatPlotLib", mpl.__version__),
      ("Numpy", np.__version__),
      ("pyAgrum", gum.__version__),
    ]
  )

  header = "<table width='100%'><tr><th>Library</th><th>Version</th></tr>"
  rows = "".join(
    f"<tr><td>{library}</td><td>{version}</td></tr>"
    for library, version in versions.items()
  )
  stamp = time.strftime('%a %b %d %H:%M:%S %Y %Z')
  footer = f"</table><div align='right'><small>{stamp}</small></div>"

  IPython.display.display(IPython.display.HTML(header + rows + footer))
195
196
def __insertLinkedSVGs(mainSvg):
  """
  Inline the SVG files linked by the ``<image xlink:href="...">`` tags of an SVG text.

  Each linked file is read (utf8); the ``viewBox`` attribute of its line starting
  with ``<svg`` and all the following lines are kept. Every ``<image .../>`` tag
  is then turned into a ``<svg ...>`` tag whose ``xlink:href`` attribute is replaced
  by that ``viewBox``, followed by the kept content of the file.
  Finally, the ``white-space:pre;`` style is removed everywhere.

  :param str mainSvg: the SVG source to process
  :return: the processed SVG source
  :raises ValueError: if a linked file contains no line starting with ``<svg``
  """
  re_images = re.compile(r"(<image [^>]*>)")
  re_xlink = re.compile(r"xlink:href=\"([^\"]*)")
  re_viewbox = re.compile(r"(viewBox=\"[^\"]*\")")

  # analyze mainSvg (find the secondary svgs)
  fragments = {}
  for img in re_images.finditer(mainSvg):
    secondarySvg = re_xlink.findall(img.group(1))[0]
    viewBox = None  # reset for each file: never reuse the viewBox of a previous one
    lines = []
    with open(secondarySvg, encoding='utf8') as f:
      inSvg = False
      for line in f:
        if line.startswith("<svg"):
          inSvg = True
          viewBox = re_viewbox.findall(line)[0]
        elif inSvg:
          lines.append(line)
    if viewBox is None:
      raise ValueError(f"No <svg> element found in linked file '{secondarySvg}'")
    fragments[secondarySvg] = (viewBox, "".join(lines))

  if len(fragments) > 0:
    # replace image tags by svg tags (raw string: \g<1> is a regex group reference)
    img2svg = re.sub(r"<image ([^>]*)/>", r"<svg \g<1>>", mainSvg)

    # insert secondaries into main
    def insertSecondarySvg(matchObj):
      vb, code = fragments[matchObj.group(1)]
      return vb + matchObj.group(2) + code

    mainSvg = re.sub(r'xlink:href="([^"]*)"(.*>)',
                     insertSecondarySvg, img2svg
                     )

  # remove buggy white-space (for notebooks)
  mainSvg = mainSvg.replace("white-space:pre;", "")
  return mainSvg
236
237
def _reprGraph(gr, size, asString, format=None):
  """
  Render a pydot graph in a notebook, either by displaying it or as a string.

  :param gr: the pydot graph
  :param string size: size of the rendered graph (left unchanged on the graph if None)
  :param boolean asString: if True, return the HTML/SVG fragment instead of displaying the graph
  :param format: "svg" for an SVG rendering, anything else for PNG; if None, the value
                 of gum.config["notebook", "graph_format"] is used
  :return: a string if asString is True, None otherwise
  """
  if size is not None:
    gr.set_size(size)

  chosen = gum.config["notebook", "graph_format"] if format is None else format

  if chosen == "svg":
    svg_text = gr.create_svg().decode('utf-8')
    svg = IPython.display.SVG(__insertLinkedSVGs(svg_text))
    if asString:
      return svg.data
    IPython.display.display(svg)
    return None

  png = IPython.core.display.Image(format="png", data=gr.create_png())
  if asString:
    return f'<img style="margin:0" src="data:image/png;base64,{encodebytes(png.data).decode()}"/>'
  IPython.core.display.display_png(png)
  return None
263
264
def showGraph(gr, size=None):
  """
  Display a pydot graph in a notebook.

  :param gr: pydot graph
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :return: the result of _reprGraph called with asString=False
  """
  chosen_size = gum.config["notebook", "default_graph_size"] if size is None else size
  return _reprGraph(gr, chosen_size, asString=False)
277
278
def getGraph(gr, size=None):
  """
  Build the HTML string representation of a pydot graph.

  :param gr: pydot graph
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :return: the HTML representation of the graph as a string
  """
  chosen_size = gum.config["notebook", "default_graph_size"] if size is None else size
  return _reprGraph(gr, chosen_size, asString=True)
291
292
def _from_dotstring(dotstring):
  """
  Build a pydot graph from a dot string, with a transparent background.

  Edges without color, and nodes without color or font color, receive the
  color given by getBlackInTheme().

  :param dotstring: dot string
  :return: the pydot graph
  """
  graph = dot.graph_from_dot_data(dotstring)
  graph.set_bgcolor("transparent")

  for edge in graph.get_edges():
    if edge.get_color() is None:
      edge.set_color(getBlackInTheme())

  for node in graph.get_nodes():
    if node.get_color() is None:
      node.set_color(getBlackInTheme())
    if node.get_fontcolor() is None:
      node.set_fontcolor(getBlackInTheme())

  return graph
305
306
def showDot(dotstring, size=None):
  """
  Display the graph described by a dot string.

  :param dotstring: dot string
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :return: the result of showGraph for the parsed graph
  """
  chosen_size = gum.config["notebook", "default_graph_size"] if size is None else size
  graph = _from_dotstring(dotstring)
  return showGraph(graph, chosen_size)
318
319
def getDot(dotstring, size=None):
  """
  Build the HTML string for the graph described by a dot string.

  :param dotstring: dot string
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :return: the HTML representation of the graph
  """
  chosen_size = gum.config["notebook", "default_graph_size"] if size is None else size
  graph = _from_dotstring(dotstring)
  return getGraph(graph, chosen_size)
334
335
def getBNDiff(bn1, bn2, size=None):
  """ build the HTML string of a graphical diff between the arcs of bn1 (reference) and those of bn2.

  * full black line: the arc is common for both
  * full red line: the arc is common but inverted in _bn2
  * dotted black line: the arc is added in _bn2
  * dotted red line: the arc is removed in _bn2

  :param BayesNet bn1: referent model for the comparison
  :param BayesNet bn2: bn compared to the referent model
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :return: the HTML representation of the diff graph
  """
  chosen_size = gum.config["notebook", "default_graph_size"] if size is None else size
  comparator = GraphicalBNComparator(bn1, bn2)
  diff_graph = comparator.dotDiff()
  return getGraph(diff_graph, chosen_size)
352
353
def showBNDiff(bn1, bn2, size=None):
  """ display a graphical diff between the arcs of bn1 (reference) and those of bn2.

  * full black line: the arc is common for both
  * full red line: the arc is common but inverted in _bn2
  * dotted black line: the arc is added in _bn2
  * dotted red line: the arc is removed in _bn2

  :param BayesNet bn1: referent model for the comparison
  :param BayesNet bn2: bn compared to the referent model
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  """
  chosen_size = gum.config["notebook", "default_graph_size"] if size is None else size
  comparator = GraphicalBNComparator(bn1, bn2)
  diff_graph = comparator.dotDiff()
  showGraph(diff_graph, chosen_size)
370
371
def showInformation(*args, **kwargs):
  """
  Deprecated entry point: print a deprecation notice, then forward all the
  arguments to pyAgrum.lib.explain.showInformation.
  """
  # the notice must name the real replacement function (showInformation)
  print(
    "[pyAgrum] pyAgrum.lib.notebook.showInformation is deprecated since 0.20.2. Please use pyAgrum.lib.explain.showInformation instead."
  )
  import pyAgrum.lib.explain as explain
  explain.showInformation(*args, **kwargs)
378
379
def showJunctionTree(bn, withNames=True, size=None):
  """
  Show a junction tree

  :param bn: the Bayesian network
  :param boolean withNames: display the variable names or the node id in the clique
  :param size: size of the rendered graph; if None, gum.config["notebook", "junctiontree_graph_size"]
               is used (same default as getJunctionTree)
  :return: the result of showDot for the junction tree
  """
  if size is None:
    # use the junction-tree specific default, consistently with getJunctionTree
    size = gum.config["notebook", "junctiontree_graph_size"]

  jtg = gum.JunctionTreeGenerator()
  jt = jtg.junctionTree(bn)

  # NOTE(review): presumably these references keep the generator and the model
  # alive as long as jt is used -- confirm
  jt._engine = jtg
  jtg._model = bn

  if withNames:
    return showDot(jt.toDotWithNames(bn), size)
  else:
    return showDot(jt.toDot(), size)
402
403
def getJunctionTree(bn, withNames=True, size=None):
  """
  Build the HTML string for a junction tree (more specifically a join tree)

  :param bn: the Bayesian network
  :param boolean withNames: display the variable names or the node id in the clique
  :param size: size of the rendered graph; if None, gum.config["notebook", "junctiontree_graph_size"] is used
  :return: the HTML representation of the graph
  """
  chosen_size = gum.config["notebook", "junctiontree_graph_size"] if size is None else size

  generator = gum.JunctionTreeGenerator()
  tree = generator.junctionTree(bn)

  # NOTE(review): presumably these references keep the generator and the model
  # alive as long as the tree is used -- confirm
  tree._engine = generator
  generator._model = bn

  dotstring = tree.toDotWithNames(bn) if withNames else tree.toDot()
  return getDot(dotstring, chosen_size)
426
427
def showProba(p, scale=1.0):
  """
  Display the histogram of a mono-dim Potential

  :param p: the mono-dim Potential
  :param scale: the scale (zoom)
  """
  proba2histo(p, scale)  # builds the figure shown below by plt.show()
  graph_format = gum.config["notebook", "graph_format"]
  set_matplotlib_formats(graph_format)
  plt.show()
438
439
def _getMatplotFig(fig):
  """
  Render a matplotlib figure as an HTML ``<img>`` tag embedding a base64 PNG,
  then close that figure.

  :param fig: the matplotlib figure (must provide ``savefig``)
  :return: the HTML ``<img>`` fragment as a string
  """
  bio = io.BytesIO()  # bytes buffer for the plot
  fig.savefig(bio, format='png', bbox_inches='tight')  # make a png of the plot in the buffer

  # encode the bytes as string using base 64
  sB64Img = base64.b64encode(bio.getvalue()).decode()
  res = f'<img src="data:image/png;base64,{sB64Img}\n">'
  # close the rendered figure itself: a bare plt.close() would close whichever
  # figure is current, which is not necessarily this one
  plt.close(fig)
  return res
449
450
def getProba(p, scale=1.0):
  """
  Build the html (png) img of the histogram of a mono-dim Potential

  :param p: the mono-dim Potential
  :param scale: the scale (zoom)
  :return: the HTML img fragment produced by _getMatplotFig
  """
  graph_format = gum.config["notebook", "graph_format"]
  set_matplotlib_formats(graph_format)
  histogram = proba2histo(p, scale)
  return _getMatplotFig(histogram)
460
461
def showProbaMinMax(pmin, pmax, scale=1.0):
  """
  Display the histogram of a bi-Potential (min,max)

  :param pmin: the mono-dim Potential for min values
  :param pmax: the mono-dim Potential for max values
  :param scale: the scale (zoom)
  """
  probaMinMaxH(pmin, pmax, scale)  # builds the figure shown below by plt.show()
  graph_format = gum.config["notebook", "graph_format"]
  set_matplotlib_formats(graph_format)
  plt.show()
473
474
def getProbaMinMax(pmin, pmax, scale=1.0):
  """
  Build the html (png) img of the histogram of a bi-Potential (min,max)

  :param pmin: the mono-dim Potential for min values
  :param pmax: the mono-dim Potential for max values
  :param scale: the scale (zoom)
  :return: the HTML img fragment produced by _getMatplotFig
  """
  graph_format = gum.config["notebook", "graph_format"]
  set_matplotlib_formats(graph_format)
  histogram = probaMinMaxH(pmin, pmax, scale)
  return _getMatplotFig(histogram)
485
486
def getPosterior(bn, evs, target):
  """
  shortcut for _getMatplotFig(proba2histo(gum.getPosterior(bn,evs,target)))

  :param bn: the BayesNet
  :type bn: gum.BayesNet
  :param evs: map of evidence
  :type evs: dict(str->int)
  :param target: name of target variable
  :type target: str
  :return: the HTML img fragment (as returned by _getMatplotFig) for the histogram of the posterior
  """
  fig = proba2histo(gum.getPosterior(bn, evs=evs, target=target))
  # no plt.close() here: _getMatplotFig closes the figure once it is rendered.
  # Closing first made that later close hit whichever other figure was current.
  return _getMatplotFig(fig)
502
503
def showPosterior(bn, evs, target):
  """
  Display the histogram of a posterior: shortcut for showProba(gum.getPosterior(bn,evs,target))

  :param bn: the BayesNet
  :param evs: map of evidence
  :param target: name of target variable
  """
  posterior = gum.getPosterior(bn, evs=evs, target=target)
  showProba(posterior)
513
514
def animApproximationScheme(apsc, scale=np.log10):
  """
  Install callbacks on an approximation algorithm so that its history is
  plotted again at each progress notification, and a summary title is set
  when it stops.

  :param apsc: the approximation algorithm
  :param scale: a function applied to the history before plotting
  """
  figure = plt.gcf()

  listener = gum.PythonApproximationListener(apsc._asIApproximationSchemeConfiguration())
  apsc.setVerbosity(True)
  apsc.listener = listener

  def on_stop(x):
    # final state: clear the output and summarize the run in the title
    IPython.display.clear_output(True)
    plt.title(
      f"{x} \n Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}"
    )

  def on_progress(x, y, z):
    # the x-axis always covers at least 10 iterations
    plt.xlim(1, max(10, len(apsc.history())))
    plt.plot(scale(apsc.history()), 'g')
    IPython.display.clear_output(True)
    IPython.display.display(figure)

  listener.setWhenStop(on_stop)
  listener.setWhenProgress(on_progress)
547
548
def showApproximationScheme(apsc, scale=np.log10):
  """
  Plot the history of an approximation algorithm (nothing is done if its
  verbosity is off).

  :param apsc: the approximation algorithm
  :param scale: a function applied to the history before plotting
  """
  if not apsc.verbosity():
    return

  # the x-axis always covers at least 10 iterations
  plt.xlim(1, max(10, len(apsc.history())))
  plt.title(
    f"Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}"
  )
  plt.plot(scale(apsc.history()), 'g')
559
560
def showMN(mn, view=None, size=None, nodeColor=None, factorColor=None, edgeWidth=None, edgeColor=None, cmap=None,
           cmapEdge=None
           ):
  """
  Display a Markov network

  :param mn: the markov network
  :param view: 'graph' | 'factorgraph' | None (then gum.config["notebook", "default_markovnetwork_view"] is used)
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param factorColor: a function returning a value (between 0 and 1) to be shown as a color of factor. (used when view is not 'graph')
  :param edgeWidth: a edgeMap of values to be shown as width of edges  (used when view='graph')
  :param edgeColor: a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph')
  :param cmap: color map to show the colors
  :param cmapEdge: color map to show the edge color if distinction is needed (defaults to cmap)
  :return: the result of showGraph for the network
  """
  view = gum.config["notebook", "default_markovnetwork_view"] if view is None else view
  size = gum.config["notebook", "default_graph_size"] if size is None else size
  cmapEdge = cmap if cmapEdge is None else cmapEdge

  if view == "graph":
    return showGraph(MN2UGdot(mn, size, nodeColor, edgeWidth, edgeColor, cmap, cmapEdge), size)

  # any other view is rendered as a factor graph
  return showGraph(MN2FactorGraphdot(mn, size, nodeColor, factorColor, cmapNode=cmap), size)
595
596
def showInfluenceDiagram(diag, size=None):
  """
  Display an influence diagram as a graph

  :param diag: the influence diagram
  :param size: size of the rendered graph; if None, gum.config["influenceDiagram", "default_id_size"] is used
  :return: the result of showGraph for the influence diagram
  """
  chosen_size = gum.config["influenceDiagram", "default_id_size"] if size is None else size
  return showGraph(ID2dot(diag), chosen_size)
609
610
def getInfluenceDiagram(diag, size=None):
  """
  Build the HTML string for an influence diagram as a graph

  :param diag: the influence diagram
  :param size: size of the rendered graph; if None, gum.config["influenceDiagram", "default_id_size"] is used
  :return: the HTML representation of the influence diagram
  """
  chosen_size = gum.config["influenceDiagram", "default_id_size"] if size is None else size
  return getGraph(ID2dot(diag), chosen_size)
623
624
def showBN(bn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None):
  """
  Display a Bayesian network

  :param bn: the Bayesian network
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the colors
  :param cmapArc: color map to show the arc color if distinction is needed (defaults to cmap)
  :return: the result of showGraph for the network
  """
  size = gum.config["notebook", "default_graph_size"] if size is None else size
  cmapArc = cmap if cmapArc is None else cmapArc

  dotgraph = BN2dot(bn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc)
  return showGraph(dotgraph, size)
645
646
def showCN(cn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None):
  """
  Display a credal network

  :param cn: the credal network
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the colors
  :param cmapArc: color map to show the arc color if distinction is needed (defaults to cmap)
  :return: the result of showGraph for the network
  """
  size = gum.config["notebook", "default_graph_size"] if size is None else size
  cmapArc = cmap if cmapArc is None else cmapArc

  dotgraph = CN2dot(cn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc)
  return showGraph(dotgraph, size)
667
668
def getMN(mn, view=None, size=None, nodeColor=None, factorColor=None, edgeWidth=None, edgeColor=None, cmap=None,
          cmapEdge=None
          ):
  """
  Build the HTML string for a Markov network

  :param mn: the markov network
  :param view: 'graph' | 'factorgraph' | None (then gum.config["notebook", "default_markovnetwork_view"] is used)
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param factorColor: a function returning a value (between 0 and 1) to be shown as a color of factor. (used when view is not 'graph')
  :param edgeWidth: a edgeMap of values to be shown as width of edges  (used when view='graph')
  :param edgeColor: a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph')
  :param cmap: color map to show the colors
  :param cmapEdge: color map to show the edge color if distinction is needed (defaults to cmap)
  :return: the HTML representation of the network
  """
  size = gum.config["notebook", "default_graph_size"] if size is None else size
  cmapEdge = cmap if cmapEdge is None else cmapEdge
  view = gum.config["notebook", "default_markovnetwork_view"] if view is None else view

  if view == "graph":
    return getGraph(MN2UGdot(mn, size, nodeColor, edgeWidth, edgeColor, cmap, cmapEdge), size)

  # any other view is rendered as a factor graph
  return getGraph(MN2FactorGraphdot(mn, size, nodeColor, factorColor, cmapNode=cmap), size)
703
704
def getBN(bn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None):
  """
  Build the HTML string for a Bayesian network

  :param bn: the Bayesian network
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the colors
  :param cmapArc: color map to show the arc color if distinction is needed (defaults to cmap)

  :return: the HTML representation of the network
  """
  size = gum.config["notebook", "default_graph_size"] if size is None else size
  cmapArc = cmap if cmapArc is None else cmapArc

  dotgraph = BN2dot(bn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc)
  return getGraph(dotgraph, size)
726
727
def getCN(cn, size=None, nodeColor=None, arcWidth=None, arcColor=None, cmap=None, cmapArc=None):
  """
  Build the HTML string for a credal network

  :param cn: the credal network
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_size"] is used
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the colors
  :param cmapArc: color map to show the arc color if distinction is needed (defaults to cmap)

  :return: the HTML representation of the network
  """
  size = gum.config["notebook", "default_graph_size"] if size is None else size
  cmapArc = cmap if cmapArc is None else cmapArc

  dotgraph = CN2dot(cn, size, nodeColor, arcWidth, arcColor, cmap, cmapArc)
  return getGraph(dotgraph, size)
749
750
def _get_showInference(model, engine=None, evs=None, targets=None, size=None,
                       nodeColor=None, factorColor=None, arcWidth=None,
                       arcColor=None, cmap=None, cmapArc=None, graph=None, view=None
                       ):
  """
  Build the graph of an inference for a graphical model, dispatching on the type of the model.

  :param model: a gum.BayesNet, gum.MarkovNet, gum.InfluenceDiagram or gum.CredalNet
  :param engine: inference engine; if None, a default one is created for the model
                 (LazyPropagation, ShaferShenoyMNInference, ShaferShenoyLIMIDInference
                 or CNMonteCarloSampling)
  :param evs: map of evidence (defaults to an empty dict)
  :param targets: targets of the inference (defaults to an empty container)
  :param size: size of the rendered graph; if None, gum.config["notebook", "default_graph_inference_size"] is used
  :param nodeColor: values to be shown as color of nodes
  :param factorColor: values to be shown as color of factors (MarkovNet only)
  :param arcWidth: values to be shown as width of arcs
  :param arcColor: values to be shown as color of arcs
  :param cmap: color map for the nodes
  :param cmapArc: color map for the arcs
  :param graph: accepted for interface compatibility; not used by this function
  :param view: 'graph' | 'factorgraph' | None (MarkovNet only; if None,
               gum.config["notebook", "default_markovnetwork_view"] is used)
  :return: the graph built by the matching ``*inference2dot`` function
  :raises gum.InvalidArgument: if model is none of the supported types
  """
  if size is None:
    size = gum.config["notebook", "default_graph_inference_size"]

  if evs is None:
    evs = {}

  if targets is None:
    targets = {}

  if isinstance(model, gum.BayesNet):
    if engine is None:
      engine = gum.LazyPropagation(model)
    return BNinference2dot(model, size=size, engine=engine, evs=evs, targets=targets, nodeColor=nodeColor,
                           arcWidth=arcWidth,
                           arcColor=arcColor,
                           cmapNode=cmap, cmapArc=cmapArc
                           )
  elif isinstance(model, gum.MarkovNet):
    # only the default view depends on `view is None`: engine creation and
    # rendering must also run when the caller gives an explicit view
    if view is None:
      view = gum.config["notebook", "default_markovnetwork_view"]
    if engine is None:
      engine = gum.ShaferShenoyMNInference(model)

    if view == "graph":
      return MNinference2UGdot(model, size=size, engine=engine, evs=evs, targets=targets, nodeColor=nodeColor,
                               factorColor=factorColor,
                               arcWidth=arcWidth, arcColor=arcColor, cmapNode=cmap, cmapArc=cmapArc
                               )
    else:
      return MNinference2FactorGraphdot(model, size=size, engine=engine, evs=evs, targets=targets,
                                        nodeColor=nodeColor,
                                        factorColor=factorColor, cmapNode=cmap
                                        )
  elif isinstance(model, gum.InfluenceDiagram):
    if engine is None:
      engine = gum.ShaferShenoyLIMIDInference(model)
    return LIMIDinference2dot(model, size=size, engine=engine, evs=evs, targets=targets)
  elif isinstance(model, gum.CredalNet):
    if engine is None:
      engine = gum.CNMonteCarloSampling(model)
    return CNinference2dot(model, size=size, engine=engine, evs=evs, targets=targets, nodeColor=nodeColor,
                           arcWidth=arcWidth, arcColor=arcColor, cmapNode=cmap
                           )
  else:
    raise gum.InvalidArgument(
      "Argument model should be a PGM (BayesNet, MarkovNet, InfluenceDiagram or CredalNet)"
    )
802
803
def showInference(model, **kwargs):
  """
  show pydot graph for an inference in a notebook

  All the keyword arguments are forwarded to _get_showInference.

  :param GraphicalModel model: the model in which to infer (pyAgrum.BayesNet, pyAgrum.MarkovNet, pyAgrum.InfluenceDiagram or pyAgrum.CredalNet)
  :param gum.Inference engine: inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, gum.ShaferShenoyMNInference for gum.MarkovNet, gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram and gum.CNMonteCarloSampling for gum.CredalNet.
  :param dictionnary evs: map of evidence
  :param set targets: set of targets
  :param string size: size of the rendered graph (if missing or None, gum.config["notebook", "default_graph_inference_size"] is used)
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param factorColor: a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovNet representation)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the color of nodes and arcs
  :param cmapArc: color map to show the vals of Arcs.
  :param graph: forwarded to _get_showInference (NOTE(review): not used there in the visible code)
  :param view: graph | factorgraph | None (default) for Markov network
  """
  # an explicit size=None must behave like a missing size
  size = kwargs.get("size")
  if size is None:
    size = gum.config["notebook", "default_graph_inference_size"]

  showGraph(_get_showInference(model, **kwargs), size)
829
830
def getInference(model, **kwargs):
  """
  get a HTML string for an inference in a notebook

  All the keyword arguments are forwarded to _get_showInference.

  :param GraphicalModel model: the model in which to infer (pyAgrum.BayesNet, pyAgrum.MarkovNet,
          pyAgrum.InfluenceDiagram or pyAgrum.CredalNet)
  :param gum.Inference engine: inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet,
          gum.ShaferShenoyMNInference for gum.MarkovNet, gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram
          and gum.CNMonteCarloSampling for gum.CredalNet.
  :param dictionnary evs: map of evidence
  :param set targets: set of targets
  :param string size: size of the rendered graph (if missing or None,
          gum.config["notebook", "default_graph_inference_size"] is used)
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param factorColor: a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovNet representation)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the color of nodes and arcs
  :param cmapArc: color map to show the vals of Arcs.
  :param graph: forwarded to _get_showInference (NOTE(review): not used there in the visible code)
  :param view: graph | factorgraph | None (default) for Markov network

  :return: the HTML string representing the inference
  """
  # an explicit size=None must behave like a missing size
  size = kwargs.get("size")
  if size is None:
    size = gum.config["notebook", "default_graph_inference_size"]

  grinf = _get_showInference(model, **kwargs)
  return getGraph(grinf, size)
860
861
def _reprPotential(pot, digits=None, withColors=True, varnames=None, asString=False):
  """
  return a representation of a gum.Potential as a HTML table.
  The first dimension is special (horizontal) due to the representation of conditional probability table

  :param pot: the potential to get
  :param digits: number of digits to show (if None, gum.config['notebook', 'potential_visible_digits'] is used)
  :param withColors: bgcolor for proba cells or not (forced to False when
          gum.config["notebook", "potential_with_colors"] is "False")
  :param varnames: the aliases for variables name in the table (must contain pot.nbrDim() values)
  :param asString: if True, return the HTML as a string, otherwise return an IPython.display.HTML object

  :return: the representation
  :raises ValueError: if varnames is given and does not contain exactly pot.nbrDim() values
  """
  # the two ends of the color scale used for the background of the cells
  r0, g0, b0 = gumcols.hex2rgb(gum.config['notebook', 'potential_color_0'])
  r1, g1, b1 = gumcols.hex2rgb(gum.config['notebook', 'potential_color_1'])

  def _rgb(r, g, b):
    return '#%02x%02x%02x' % (r, g, b)

  def _mkCell(val):
    # one <td> per value; the background is interpolated between the two configured
    # colors only when colors are requested and the value lies in [0, 1]
    s = "<td style='"
    if withColors and (0 <= val <= 1):
      r = int(r0 + val * (r1 - r0))
      g = int(g0 + val * (g1 - g0))
      b = int(b0 + val * (b1 - b0))

      # CSS color of the text, chosen by gumcols from the background color
      tx = gumcols.rgb2brightness(r, g, b)

      s += "color:" + tx + ";background-color:" + _rgb(r, g, b) + ";"
    s += "text-align:right;'>{:." + str(digits) + "f}</td>"
    return s.format(val)

  if digits is None:
    digits = gum.config['notebook', 'potential_visible_digits']

  if gum.config["notebook", "potential_with_colors"] == "False":
    withColors = False

  html = []
  html.append('<table style="border:1px solid black;">')
  if pot.empty():
    # no variable at all: a single cell holding the only value
    html.append(
      "<tr><th>&nbsp;</th></tr>"
    )
    html.append("<tr>" + _mkCell(pot.get(gum.Instantiation())) + "</tr>")
  else:
    if varnames is not None and len(varnames) != pot.nbrDim():
      raise ValueError(
        f"varnames contains {len(varnames)} value(s) instead of the needed {pot.nbrDim()} value(s)."
      )

    nparents = pot.nbrDim() - 1
    var = pot.variable(0)
    varname = var.name() if varnames is None else varnames[0]

    # read once: this setting drives both the header order and the row merging below
    parent_values = gum.config["notebook", "potential_parent_values"]

    # first line
    if nparents > 0:
      html.append(f"""<tr><th colspan='{nparents}'></th>
      <th colspan='{var.domainSize()}' style='border:1px solid black;color:black;background-color:#808080;'><center>{varname}</center>
      </th></tr>"""
                  )
    else:
      html.append(f"""<tr style='border:1px solid black;color:black;background-color:#808080'>
      <th colspan='{var.domainSize()}'><center>{varname}</center></th></tr>"""
                  )

    # second line
    s = "<tr>"
    if nparents > 0:
      # parents order
      if parent_values == "revmerge":
        pmin, pmax, pinc = nparents - 1, 0 - 1, -1
      else:
        pmin, pmax, pinc = 0, nparents, 1
      # NOTE(review): with aliases, indices 0..nparents-1 of varnames are used here while
      # varnames[0] is already used above for the first variable, so varnames[nparents]
      # is never shown -- confirm the expected ordering of varnames against pot.var_names.
      for par in range(pmin, pmax, pinc):
        parent = pot.var_names[par] if varnames is None else varnames[par]
        s += f"<th style='border:1px solid black;color:black;background-color:#808080'><center>{parent}</center></th>"

    for label in var.labels():
      s += f"""<th style='border:1px solid black;border-bottom-style: double;color:black;background-color:#BBBBBB'>
      <center>{label}</center></th>"""
    s += '</tr>'

    html.append(s)

    inst = gum.Instantiation(pot)
    # offset[i] = product of the domain sizes of the parents 1..i-1, used as rowspan
    # for the merged header cell of parent i
    off = 1
    offset = {}
    for i in range(1, nparents + 1):
      offset[i] = off
      off *= inst.variable(i).domainSize()

    inst.setFirst()
    while not inst.end():
      s = "<tr>"
      # parents order
      if parent_values == "revmerge":
        pmin, pmax, pinc = 1, nparents + 1, 1
      else:
        pmin, pmax, pinc = nparents, 0, -1
      for par in range(pmin, pmax, pinc):
        label = inst.variable(par).label(inst.val(par))
        if par == 1 or parent_values == "nomerge":
          s += f"<th style='border:1px solid black;color:black;background-color:#BBBBBB'><center>{label}</center></th>"
        else:
          # merged cell: only emitted on the row where every parent before par is at its first value
          if sum(inst.val(i) for i in range(1, par)) == 0:
            s += f"""<th style='border:1px solid black;color:black;background-color:#BBBBBB;' rowspan = '{offset[par]}'>
            <center>{label}</center></th>"""
      # one row = all the values of the first variable; inst.inc() moves to the next cell
      for _ in range(var.domainSize()):
        s += _mkCell(pot.get(inst))
        inst.inc()
      s += "</tr>"
      html.append(s)

  html.append("</table>")

  if asString:
    return "\n".join(html)
  else:
    return IPython.display.HTML("".join(html))
985
986
def __isKindOfProba(pot):
  """
  tell whether pot looks like a joint probability or a conditional probability table

  :param pot: the potential
  :return: True if all values are in [0,1] and either the whole potential sums to 1 or
           (for at least 2 dimensions) summing out the first variable gives 1 everywhere;
           False otherwise. All comparisons use a tolerance of 1e-5.
  """
  tolerance = 1e-5

  # a probability never leaves [0,1]
  if pot.min() < -tolerance or pot.max() > 1 + tolerance:
    return False

  # joint probability: everything sums to 1
  if abs(pot.sum() - 1) < tolerance:
    return True

  # a single-variable potential that does not sum to 1 is not a probability
  if pot.nbrDim() < 2:
    return False

  # CPT: once the first variable is summed out, every remaining value must be 1
  sums = pot.margSumOut([pot.variable(0).name()])
  return not (abs(sums.max() - 1) > tolerance or abs(sums.min() - 1) > tolerance)
1014
1015
def showPotential(pot, digits=None, withColors=None, varnames=None):
  """
  display a gum.Potential as a HTML table in the notebook.
  The first dimension is special (horizontal) due to the representation of conditional probability table

  :param gum.Potential pot: the potential to show
  :param int digits: number of digits to show
  :param boolean withColors: bgcolor for proba cells or not. If None, colors are used only when
          gum.config["notebook", "potential_with_colors"] is not "False" and pot looks like a probability
  :param list of strings varnames: the aliases for variables name in the table
  """
  if withColors is None:
    colors_disabled = gum.config["notebook", "potential_with_colors"] == "False"
    withColors = False if colors_disabled else __isKindOfProba(pot)

  IPython.display.display(_reprPotential(pot, digits, withColors, varnames, asString=False))
1035
1036
def getPotential(pot, digits=None, withColors=None, varnames=None):
  """
  build the HTML table (as a string) representing a gum.Potential.
  The first dimension is special (horizontal) due to the representation of conditional probability table

  :param gum.Potential pot: the potential to get
  :param int digits: number of digits to show
  :param boolean withColors: bgcolor for proba cells or not. If None, colors are used only when
          gum.config["notebook", "potential_with_colors"] is not "False" and pot looks like a probability
  :param list of strings varnames: the aliases for variables name in the table
  :return: the HTML string
  """
  if withColors is None:
    colors_disabled = gum.config["notebook", "potential_with_colors"] == "False"
    withColors = False if colors_disabled else __isKindOfProba(pot)

  html_table = _reprPotential(pot, digits, withColors, varnames, asString=True)
  return html_table
1055
1056
def getSideBySide(*args, **kwargs):
  """
  build a one-row HTML table with one cell per arg (using the string itself, _repr_html_() or str())

  :param args: HTML fragments as string arg, arg._repr_html_() or str(arg)
  :param captions: list of strings (captions), shown in a second row
  :param valign: value of the CSS vertical-align property applied to each cell
  :return: a string representing the table
  :raises TypeError: if a keyword other than captions or valign is given
  """
  allowed = {'captions', 'valign'}
  unexpected = set(kwargs) - allowed
  if unexpected:
    raise TypeError(f"sideBySide() got unexpected keyword argument(s) : '{unexpected}'")

  captions = kwargs.get('captions')
  v_align = ('vertical-align:' + kwargs['valign'] + ';') if 'valign' in kwargs else ""

  def as_html(fragment):
    # strings are taken as is, rich objects through _repr_html_, anything else through str()
    if isinstance(fragment, str):
      return fragment
    if hasattr(fragment, '_repr_html_'):
      return fragment._repr_html_()
    return str(fragment)

  cell_open = ('<td style="border-top:hidden;border-bottom:hidden;' + v_align
               + '"><div align="center" style="' + v_align + '">')
  cell_close = '</div></td>'

  pieces = [
    '<table style="border-style: hidden; border-collapse: collapse;" width="100%">',
    '<tr>',
    cell_open,
    (cell_close + cell_open).join(as_html(arg) for arg in args),
    cell_close,
    '</tr>',
  ]

  if captions is not None:
    caption_open = '<td style="border-top:hidden;border-bottom:hidden;"><div align="center"><small>'
    caption_close = '</small></div></td>'
    pieces += [
      '<tr>',
      caption_open,
      (caption_close + caption_open).join(captions),
      caption_close,
      '</tr>',
    ]

  pieces.append('</table>')
  return "".join(pieces)
1108
1109
def sideBySide(*args, **kwargs):
  """
  display args side by side in the notebook as a HTML table (see getSideBySide for the accepted arguments)

  :param args: HTML fragments as string arg, arg._repr_html_() or str(arg)
  :param captions: list of strings (captions)
  """
  table = getSideBySide(*args, **kwargs)
  IPython.display.display(IPython.display.HTML(table))
1118
1119
def getInferenceEngine(ie, inferenceCaption):
  """
  build the HTML showing an inference as a BN + lists of hard/soft evidence and list of targets

  :param gum.InferenceEngine ie: inference engine
  :param string inferenceCaption: title for caption
  :return: the HTML string built by getSideBySide (the BN on the left, the lists on the right)
  """

  def names_of(nodes):
    # comma-separated names of the given node ids in the BN of the engine
    return ", ".join([ie.BN().variable(n).name() for n in nodes])

  items = ['<div align="left"><ul>']

  if ie.nbrHardEvidence() > 0:
    items.append("<li><b>hard evidence</b><br/>" + names_of(ie.hardEvidenceNodes()) + "</li>")

  if ie.nbrSoftEvidence() > 0:
    items.append("<li><b>soft evidence</b><br/>" + names_of(ie.softEvidenceNodes()) + "</li>")

  if ie.nbrTargets() > 0:
    # when every node is a target, do not enumerate them
    if ie.nbrTargets() == ie.BN().size():
      listed = " all"
    else:
      listed = names_of(ie.targets())
    items.append("<li><b>target(s)</b><br/>" + listed + "</li>")

  # only some engines provide joint targets
  if hasattr(ie, 'nbrJointTargets') and ie.nbrJointTargets() > 0:
    joint = ", ".join(['[' + names_of(ns) + ']' for ns in ie.jointTargets()])
    items.append("<li><b>Joint target(s)</b><br/>" + joint + "</li>")

  items.append('</ul></div>')
  return getSideBySide(getBN(ie.BN()), "".join(items), captions=[inferenceCaption, "Evidence and targets"])
1161
1162
def getJT(jt, size=None):
  """
  build the dot representation of a junction tree: one filled node per clique and one box per
  separator, each separator being linked to the two cliques of its edge.

  Cliques and separators are labelled with variable names if
  gum.config["notebook", "junctiontree_with_names"] is "True", with node ids otherwise.

  :param jt: the junction tree (the model is read from jt._engine._model)
  :param size: size given to the graph; if None, gum.config["notebook", "junctiontree_graph_size"] is used
  :return: the string produced by to_string() on the built graph
  """
  model = jt._engine._model

  if gum.config["notebook", "junctiontree_with_names"] == "True":
    def members(nodeset):
      # variable names, sorted alphabetically
      return sorted([model.variable(n).name() for n in nodeset])

    sep_mark = '+'
  else:
    def members(nodeset):
      # node ids as strings, sorted by id
      return [str(n) for n in sorted(nodeset)]

    sep_mark = '^'

  def cliqlabels(c):
    return " ".join(members(jt.clique(c)))

  def cliqnames(c):
    return "-".join(members(jt.clique(c)))

  def seplabels(c1, c2):
    return " ".join(members(jt.separator(c1, c2)))

  def sepnames(c1, c2):
    return cliqnames(c1) + sep_mark + cliqnames(c2)

  name = model.propertyWithDefault(
    "name", str(type(model)).split(".")[-1][:-2]
  )
  graph = dot.Dot(graph_type='graph', graph_name=name, bgcolor="transparent")

  # cliques
  for c in jt.nodes():
    graph.add_node(dot.Node('"' + cliqnames(c) + '"',
                            label='"' + cliqlabels(c) + '"',
                            style="filled",
                            fillcolor=gum.config["notebook", "junctiontree_clique_bgcolor"],
                            fontcolor=gum.config["notebook", "junctiontree_clique_fgcolor"],
                            fontsize=gum.config["notebook", "junctiontree_clique_fontsize"]
                            )
                   )

  # separators (one box per edge of the junction tree)
  for c1, c2 in jt.edges():
    graph.add_node(dot.Node('"' + sepnames(c1, c2) + '"',
                            label='"' + seplabels(c1, c2) + '"',
                            style="filled",
                            shape="box", width="0", height="0", margin="0.02",
                            fillcolor=gum.config["notebook", "junctiontree_separator_bgcolor"],
                            fontcolor=gum.config["notebook", "junctiontree_separator_fgcolor"],
                            fontsize=gum.config["notebook", "junctiontree_separator_fontsize"]
                            )
                   )

  # clique -- separator -- clique
  for c1, c2 in jt.edges():
    graph.add_edge(dot.Edge('"' + cliqnames(c1) + '"', '"' + sepnames(c1, c2) + '"'))
    graph.add_edge(dot.Edge('"' + sepnames(c1, c2) + '"', '"' + cliqnames(c2) + '"'))

  # an explicit size used to be silently ignored; the default is unchanged
  if size is None:
    size = gum.config["notebook", "junctiontree_graph_size"]
  graph.set_size(size)

  return graph.to_string()
1240
1241
def getCliqueGraph(cg, size=None):
  """get a representation for clique graph. Special case for junction tree
  (clique graph with an attribute `_engine`)

  Args:
      cg (gum.CliqueGraph): the clique graph (maybe junction tree for a _model) to
                            represent
      size: size forwarded to getDot for both kinds of clique graph (None by default)

  Returns:
      the result of getDot on the dot description of the graph
  """
  if hasattr(cg, "_engine"):
    return getDot(getJT(cg), size)
  # size used to be dropped for a plain clique graph; forward it as in the branch above
  return getDot(cg.toDot(), size)
1254
1255
def show(model, **kwargs):
  """
  propose a (visual) representation of a model in a notebook, chosen from the type of the model

  :param GraphicalModel model: the model to show (pyAgrum.BayesNet, pyAgrum.MarkovNet, pyAgrum.InfluenceDiagram,
          pyAgrum.CredalNet, pyAgrum.Potential or any object with a toDot() method)

  :param int size: optional size for the graphical model (no effect for Potential)
  :raises gum.InvalidArgument: if the model is of none of these kinds
  """
  if isinstance(model, gum.BayesNet):
    showBN(model, **kwargs)
    return

  if isinstance(model, gum.MarkovNet):
    showMN(model, **kwargs)
    return

  if isinstance(model, gum.InfluenceDiagram):
    showInfluenceDiagram(model, **kwargs)
    return

  if isinstance(model, gum.CredalNet):
    showCN(model, **kwargs)
    return

  if isinstance(model, gum.Potential):
    # the keyword arguments are not forwarded for a potential
    showPotential(model)
    return

  if hasattr(model, "toDot"):
    # generic fallback: anything that can describe itself in dot
    showDot(dot.graph_from_dot_data(model.toDot()), **kwargs)
    return

  raise gum.InvalidArgument(
    "Argument model should be a PGM (BayesNet, MarkovNet, Influence Diagram or Potential or ..."
  )
1280
1281
def export(model, filename, **kwargs):
  """
  export the graphical representation of the model in filename (png, pdf,etc.)

  :param GraphicalModel model: the model to show (pyAgrum.BayesNet, pyAgrum.MarkovNet, pyAgrum.InfluenceDiagram,
          pyAgrum.CredalNet or any object with a toDot() method)
  :param str filename: the name of the resulting file (suffix in ['pdf', 'png', 'fig', 'jpg', 'svg', 'ps'])
  :param kwargs: forwarded to the function building the dot graph of the model
  :raises Exception: if the suffix of filename is not one of the accepted ones
  :raises gum.InvalidArgument: if the model is of none of the accepted kinds
  """
  allowed = ['pdf', 'png', 'fig', 'jpg', 'svg', 'ps']
  # the text after the last dot (the whole name if there is no dot)
  extension = filename.split(".")[-1]
  if extension not in allowed:
    # the list in the message is built from `allowed` so that it cannot miss an accepted suffix
    raise Exception(
      f"{filename} in not a correct filename for export : extension '{extension}' not in [{','.join(allowed)}]."
    )

  if isinstance(model, gum.BayesNet):
    fig = BN2dot(model, **kwargs)
  elif isinstance(model, gum.MarkovNet):
    # a Markov network is exported with the view selected in the configuration
    if gum.config["notebook", "default_markovnetwork_view"] == "graph":
      fig = MN2UGdot(model, **kwargs)
    else:
      fig = MN2FactorGraphdot(model, **kwargs)
  elif isinstance(model, gum.InfluenceDiagram):
    fig = ID2dot(model, **kwargs)
  elif isinstance(model, gum.CredalNet):
    fig = CN2dot(model, **kwargs)
  elif hasattr(model, "toDot"):
    fig = dot.graph_from_dot_data(model.toDot(), **kwargs)
  else:
    raise gum.InvalidArgument(
      "Argument model should be a PGM (BayesNet, MarkovNet or Influence Diagram"
    )
  fig.write(filename, format=extension)
1313
1314
def exportInference(model, filename, **kwargs):
  """
  write the graphical representation of an inference in a file

  The representation is rendered as svg and then converted with cairosvg (imported only once
  the filename has been accepted). Nothing is returned.

  :param GraphicalModel model: the model in which to infer (pyAgrum.BayesNet, pyAgrum.MarkovNet or
          pyAgrum.InfluenceDiagram)
  :param str filename: the name of the resulting file (suffix in ['pdf', 'png', 'ps'])
  :param gum.Inference engine: inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet,
          gum.ShaferShenoy for gum.MarkovNet and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram.
  :param dictionnary evs: map of evidence
  :param set targets: set of targets
  :param string size: size of the rendered graph
  :param nodeColor: a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  :param factorColor: a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovNet representation)
  :param arcWidth: a arcMap of values to be shown as width of arcs
  :param arcColor: a arcMap of values (between 0 and 1) to be shown as color of arcs
  :param cmap: color map to show the color of nodes and arcs
  :param cmapArc: color map to show the vals of Arcs.
  :param graph: only shows nodes that have their id in the graph (and not in the whole BN)
  :param view: graph | factorgraph | None (default) for Markov network
  :raises Exception: if the suffix of filename is not pdf, png or ps
  """
  extension = filename.split(".")[-1]
  if extension not in ('pdf', 'png', 'ps'):
    raise Exception(
      f"{filename} in not a correct filename for export : extension '{extension}' not in [pdf,png,ps]."
    )

  import cairosvg

  size = kwargs['size'] if "size" in kwargs else gum.config["notebook", "default_graph_inference_size"]

  inference_graph = _get_showInference(model, **kwargs)
  svgtxt = _reprGraph(inference_graph, size=size, asString=True, format="svg")

  # extension is one of pdf/png/ps here, hence svg2pdf, svg2png or svg2ps
  convert = getattr(cairosvg, "svg2" + extension)
  convert(bytestring=svgtxt, write_to=filename)
1360
1361
def _update_config():
  """
  hook to control some parameters for notebook when config changes: the face color of the
  matplotlib figures and the format given to set_matplotlib_formats
  """
  facecolor = gum.config["notebook", "figure_facecolor"]
  mpl.rcParams['figure.facecolor'] = facecolor

  graph_format = gum.config["notebook", "graph_format"]
  set_matplotlib_formats(graph_format)
1366
1367
# check if an instance of ipython exists: get_ipython is only defined inside IPython
try:
  get_ipython
except NameError:
  import warnings

  warnings.warn("""
  ** pyAgrum.lib.notebook has to be imported from an IPython's instance (mainly notebook).
  """
                )
else:
  # keep the notebook settings in sync with gum.config, and apply them once right now
  gum.config.add_hook(_update_config)
  gum.config.run_hooks()

  # adding _repr_html_ to some pyAgrum classes !
  # graphical models
  gum.BayesNet._repr_html_ = lambda self: getBN(self)
  gum.BayesNetFragment._repr_html_ = lambda self: getBN(self)
  gum.MarkovNet._repr_html_ = lambda self: getMN(self)
  gum.InfluenceDiagram._repr_html_ = lambda self: getInfluenceDiagram(self)
  gum.CredalNet._repr_html_ = lambda self: getCN(self)

  gum.CliqueGraph._repr_html_ = lambda self: getCliqueGraph(self)

  gum.Potential._repr_html_ = lambda self: getPotential(self)
  gum.LazyPropagation._repr_html_ = lambda self: getInferenceEngine(
    self, "Lazy Propagation on this BN"
  )

  # graphs: rendered from their own dot description
  gum.UndiGraph._repr_html_ = lambda self: getDot(self.toDot())
  gum.DiGraph._repr_html_ = lambda self: getDot(self.toDot())
  gum.MixedGraph._repr_html_ = lambda self: getDot(self.toDot())
  gum.DAG._repr_html_ = lambda self: getDot(self.toDot())
  gum.EssentialGraph._repr_html_ = lambda self: getDot(self.toDot())
  gum.MarkovBlanket._repr_html_ = lambda self: getDot(self.toDot())
1403