1"""
The purpose of this module is to provide basic tools for dealing with dynamic Bayesian networks (dBN): modeling, visualisation and inference.
3"""
4
5# (c) Copyright by Pierre-Henri Wuillemin, UPMC, 2017
6#   (pierre-henri.wuillemin@lip6.fr)
7
8# Permission to use, copy, modify, and distribute this
9# software and its documentation for any purpose and
10# without fee or royalty is hereby granted, provided
11# that the above copyright notice appear in all copies
12# and that both that copyright notice and this permission
13# notice appear in supporting documentation or portions
14# thereof, including modifications, that you make.
15
16# THE AUTHOR P.H. WUILLEMIN  DISCLAIMS ALL WARRANTIES
17# WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED
18# WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT
19# SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT
20# OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
21# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
22# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
23# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
24# OR PERFORMANCE OF THIS SOFTWARE!
25
26import numpy as np
27import pydotplus as dot
28
29import matplotlib.pyplot as plt
30from matplotlib.patches import Rectangle
31
32import pyAgrum as gum
33
# Timeslice key returned by _splitName (and used by getTimeSlicesRange) for variables whose
# name carries no timeslice suffix.
noTimeCluster: str = "void"
35
36
def _splitName(name):
  """
  Split the name of a (possibly) dynamic variable into its radical and its timeslice.

  By convention, the name of a dynamic variable "X" in a dBN may be
    - "X0" for timeslice 0, both in an unrolled BN and in a 2TBN
    - "Xt" for timeslice t in a 2TBN
    - "X"+str(i) for timeslice i with integer i>0 in an unrolled BN
    - any other name is not in a timeslice

  A name made only of digits, or an empty name, is not in a timeslice.

  @argument name : str (name of the dynamic variable)
  @return static_name,timeslice with timeslice = noTimeCluster, "t" or str(i)
  """
  if name == "":
    # nothing to split (and name[-1] would raise IndexError)
    return name, noTimeCluster

  if name[-1] == "t":
    return name[:-1], "t"

  # walk back over the trailing run of digits; `start` ends on its first character
  start = len(name)
  while start > 0 and name[start - 1].isdigit():
    start -= 1

  if start == len(name) or start == 0:
    # no trailing digit, or the whole name is made of digits : not a timed variable
    return name, noTimeCluster

  return name[:start], name[start:]
59
60
61def _isInFirstTimeSlice(name):
62  """
63  @return true if there is a 0 at the end of name
64  """
65  return name[-1] == "0"
66
67
68def _isInSecondTimeSlice(name):
69  """
70  @return true if there is a t at the end of name
71  """
72  return name[-1] == "t"
73
74
75def _isInNoTimeSlice(name):
76  return name[-1] not in ["0", "t"]
77
78
def realNameFrom2TBNname(name, ts):
  """
  Build the name of a variable in a given timeslice from its name in the 2TBN (no check).

  If the 2TBN name ends with "0" or "t", its last character is replaced by `ts`;
  any other name is returned unchanged.

  :param name: the name of the variable in the 2TBN (e.g. "X0", "Xt" or "Y")
  :param ts: the timeslice (formatted with an f-string, so an int or a str)
  :return: the name of the variable in timeslice ts (e.g. "X3" for ("Xt", 3))
  """
  if _isInNoTimeSlice(name):
    return name
  radical = name[:-1]
  return f"{radical}{ts}"
84
85
def getTimeSlicesRange(dbn):
  """
  Group the variables of a dBN by timeslice.

  :param dbn: a 2TBN or an unrolled BN
  :return: a dict mapping each timeslice found (noTimeCluster, "t" or str(i), as given by _splitName)
           to the list of (name, radical) pairs of its variables, in node iteration order

  e.g.
  keys '0' and 't' (plus possibly noTimeCluster) for a classic 2TBN
  keys '0', '1', ..., str(T-1) for a classic unrolled BN
  """
  groups = {}
  for node in dbn.nodes():
    fullname = dbn.variable(node).name()
    radical, sliceKey = _splitName(fullname)
    groups.setdefault(sliceKey, []).append((fullname, radical))
  return groups
108
109
def is2TBN(bn):
  """
  Check if bn is a 2 TimeSlice Bayesian network

  Every variable must be suffixed by "0", by "t" or be in no timeslice, and each pair of
  variables "X0"/"Xt" must have the same domain size.

  :param bn: the Bayesian network
  :return: a pair (ok, message) : ok is True if the BN is syntactically correct to be a 2TBN;
           otherwise message explains why it is not.
  """
  ts = getTimeSlicesRange(bn)
  if not set(ts.keys()) <= {noTimeCluster, "0", "t"}:
    return False, "Some variables are not correctly suffixed."

  # a timeslice may contain no variable at all : .get avoids a KeyError in that case
  domainSizes = {radical: bn.variable(name).domainSize() for name, radical in ts.get('t', [])}

  res = ""
  for name, radical in ts.get('0', []):
    if radical in domainSizes and domainSizes[radical] != bn.variable(name).domainSize():
      # accumulate so that every mismatching pair is reported, not only the last one
      res += f"\n - for variables {name}/{radical}t"

  if res != "":
    return False, "Domain size mismatch : " + res

  return True, ""
135
136
def _TimeSlicesToDot(dbn):
  """
  Try to correctly represent dBN and 2TBN in dot format

  Each timeslice becomes a grey cluster (variables in no timeslice are added directly to the
  graph), the arcs of the dBN are drawn in blue without layout constraint, and invisible edges
  chain the variables of each timeslice (in sorted order) to constrain their layout.

  :param dbn: a 2TBN or an unrolled BN
  :return: the pydotplus graph
  """
  timeslices = getTimeSlicesRange(dbn)

  def _sliceOrder(key):
    # variables in no timeslice first, then numeric timeslices, then "t" last
    if key == noTimeCluster:
      return -1
    if key == 't':
      return 1e8
    return int(key)

  orderedSlices = sorted(timeslices.keys(), key=_sliceOrder)

  # dynamic member makes pylint unhappy
  # pylint: disable=no-member
  g = dot.Dot(graph_type='digraph')
  # graphviz rankdir only accepts TB, LR, BT or RL ("TD" is not a valid value)
  g.set_rankdir("TB")
  g.set_splines("ortho")

  for k in orderedSlices:
    if k != noTimeCluster:
      cluster = dot.Cluster(k, label=f"Time slice {k}", bgcolor="#DDDDDD", rankdir="TB")
      g.add_subgraph(cluster)
    else:
      cluster = g  # small trick to add in graph variable in no timeslice
    for (n, label) in sorted(timeslices[k]):
      cluster.add_node(dot.Node('"' + n + '"', label='"' + label + '"', style='filled',
                                color='#000000', fillcolor='white'
                                )
                       )

  for tail, head in dbn.arcs():
    g.add_edge(dot.Edge('"' + dbn.variable(tail).name() + '"',
                        '"' + dbn.variable(head).name() + '"',
                        constraint=False, color="blue"
                        )
               )

  for k in orderedSlices:
    if k == noTimeCluster:
      continue
    names = [n for (n, _) in sorted(timeslices[k])]
    # invisible edge between each pair of consecutive variables of the timeslice
    for prec, n in zip(names, names[1:]):
      g.add_edge(dot.Edge('"' + prec + '"',
                          '"' + n + '"',
                          style="invis"
                          )
                 )

  return g
181
182
def showTimeSlices(dbn, size=None):
  """
  Display (through pyAgrum.lib.notebook.showGraph) the timeslice representation of a dBN or a 2TBN.

  :param dbn: the dynamic BN
  :param size: size of the fig (when None, gum.config["dynamicBN", "default_graph_size"] is used)
  """
  # jupyter notebooks is optional
  # pylint: disable=import-outside-toplevel
  from pyAgrum.lib.notebook import showGraph

  figsize = gum.config["dynamicBN", "default_graph_size"] if size is None else size
  showGraph(_TimeSlicesToDot(dbn), figsize)
198
199
def getTimeSlices(dbn, size=None):
  """
  Build (through pyAgrum.lib.notebook.getGraph) an HTML string representing a dBN or a 2TBN
  by timeslices.

  :param dbn: the dynamic BN
  :param size: size of the fig (when None, gum.config["dynamicBN", "default_graph_size"] is used)
  :return: whatever getGraph returns for the dot representation of dbn
  """
  # jupyter notebooks is optional
  # pylint: disable=import-outside-toplevel
  from pyAgrum.lib.notebook import getGraph

  figsize = gum.config["dynamicBN", "default_graph_size"] if size is None else size
  return getGraph(_TimeSlicesToDot(dbn), figsize)
214
215
def unroll2TBN(dbn, nbr):
  """
  unroll a 2TBN given the nbr of timeslices

  Variables suffixed by "0" and variables in no timeslice are copied as is; each variable "Xt" is
  copied as "X1", ..., "X<nbr-1>". Arcs and CPTs are replicated accordingly : for "X<i>", a parent
  "Y0" of "Xt" becomes "Y<i-1>" and a parent "Yt" becomes "Y<i>".

  :param dbn: the dBN
  :param nbr: the number of timeslices (timeslice 0 included)

  :return: unrolled BN from a 2TBN and the nbr of timeslices
  :raises TypeError: if dbn is not a 2TBN, or if it contains an arc from timeslice t to timeslice 0
  """
  timeslices = getTimeSlicesRange(dbn)
  keys = set(timeslices.keys())
  # only the timeslices "0" and "t" (plus variables in no timeslice) are allowed, and both are needed
  if not (keys <= {noTimeCluster, "0", "t"} and {"0", "t"} <= keys):
    raise TypeError("unroll2TBN needs a 2-TimeSlice BN")

  bn = gum.BayesNet()

  # variable creation
  for dbn_id in dbn.nodes():
    name = dbn.variable(dbn_id).name()
    if not _isInSecondTimeSlice(name):
      # variables in timeslice 0 or in no timeslice : create a clone of the variable in the new bn
      bn.add(dbn.variable(dbn_id))
    else:
      for t in range(1, nbr):
        # create a clone of the variable in the new bn, then give it its true name
        nid = bn.add(dbn.variable(dbn_id))
        bn.changeVariableName(nid, realNameFrom2TBNname(name, t))

  # add parents
  # the main pb : to have the same order for parents w.r.t the order in 2TBN
  for dbn_id in dbn.nodes():
    name = dbn.variable(dbn_id).name()
    # right order for parents : var_names presumably lists the CPT variables in reverse order
    # with the child last, hence pop (drop the child) + reverse
    parentNames = dbn.cpt(dbn_id).var_names
    parentNames.pop()
    parentNames.reverse()

    for name_parent in parentNames:
      if not _isInSecondTimeSlice(name):
        if not _isInSecondTimeSlice(name_parent):
          bn.addArc(bn.idFromName(name_parent), bn.idFromName(name))
        else:
          if _isInFirstTimeSlice(name):
            raise TypeError("An arc from timeslice t to timeslice 0 is impossible in dBN")
          for t in range(1, nbr):
            new_name_parent = realNameFrom2TBNname(name_parent, t)  # current TimeSlice
            bn.addArc(bn.idFromName(new_name_parent), bn.idFromName(name))
      else:
        for t in range(1, nbr):
          if _isInFirstTimeSlice(name_parent):
            new_name_parent = realNameFrom2TBNname(name_parent, t - 1)  # last TimeSlice
          else:
            new_name_parent = realNameFrom2TBNname(name_parent, t)  # current TimeSlice
          new_name = realNameFrom2TBNname(name, t)  # necessary current TimeSlice
          bn.addArc(bn.idFromName(new_name_parent), bn.idFromName(new_name))

  # potential creation : parents were added in the 2TBN order, so the arrays can be copied as is
  for dbn_id in dbn.nodes():
    name = dbn.variable(dbn_id).name()
    if not _isInSecondTimeSlice(name):
      bn.cpt(bn.idFromName(name))[:] = dbn.cpt(dbn_id)[:]
    else:
      for t in range(1, nbr):
        bn.cpt(bn.idFromName(realNameFrom2TBNname(name, t)))[:] = dbn.cpt(dbn_id)[:]

  return bn
286
287
def plotFollowUnrolled(lovars, dbn, T, evs):
  """
  plot the dynamic evolution of a list of vars with a dBN

  For each variable "X" of lovars, the posteriors of "X0", ..., "X<T-1>" given evs are drawn
  as a stacked area chart (one figure per variable).

  :param lovars: list of variables to follow (names without timeslice suffix)
  :param dbn: the unrolled dbn
  :param T: the time range
  :param evs: observations
  """
  ie = gum.LazyPropagation(dbn)
  ie.setEvidence(evs)
  ie.makeInference()

  x = np.arange(T)

  for var in lovars:
    v0 = dbn.variableFromName(var + "0")
    domainSize = v0.domainSize()

    # one posterior query per timeslice (instead of one per timeslice and per label)
    byTime = []
    for t in range(T):
      post = ie.posterior(dbn.idFromName(var + str(t)))
      byTime.append([post[i] for i in range(domainSize)])
    # stackplot expects one serie (over time) per label
    lpots = [[values[i] for values in byTime] for i in range(domainSize)]

    _, ax = plt.subplots()
    plt.ylim(top=1, bottom=0)
    ax.xaxis.grid()
    plt.title(f"Following variable {var}", fontsize=20)
    plt.xlabel('time')

    stack = ax.stackplot(x, lpots)

    # proxy artists so that the legend shows the color of each label
    proxy_rects = [Rectangle((0, 0), 1, 1, fc=pc.get_facecolor()[0])
                   for pc in stack]
    labels = [v0.label(i) for i in range(domainSize)]
    plt.legend(proxy_rects, labels, loc='center left',
               bbox_to_anchor=(1, 0.5), ncol=1, fancybox=True, shadow=True
               )

    plt.show()
328
329
def plotFollow(lovars, twoTdbn, T, evs):
  """
  Plot the evolution of variables of a 2TBN over a time window of size T, given evidence on the sequence.

  The 2TBN is unrolled over T timeslices, then plotFollowUnrolled is applied to the unrolled BN.

  :param lovars: list of variables to follow
  :param twoTdbn: the two-timeslice dbn
  :param T: the time range
  :param evs: observations
  """
  unrolled = unroll2TBN(twoTdbn, T)
  plotFollowUnrolled(lovars, unrolled, T, evs)
340