"""
The purpose of this module is to provide basic tools for dealing with dynamic Bayesian networks: modeling, visualisation and inference.
"""

# (c) Copyright by Pierre-Henri Wuillemin, UPMC, 2017
# (pierre-henri.wuillemin@lip6.fr)

# Permission to use, copy, modify, and distribute this
# software and its documentation for any purpose and
# without fee or royalty is hereby granted, provided
# that the above copyright notice appear in all copies
# and that both that copyright notice and this permission
# notice appear in supporting documentation or portions
# thereof, including modifications, that you make.

# THE AUTHOR P.H. WUILLEMIN DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT
# OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
# OR PERFORMANCE OF THIS SOFTWARE!
import numpy as np
import pydotplus as dot

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

import pyAgrum as gum

# Pseudo time-slice key used for variables that do not belong to any time slice.
noTimeCluster = "void"


def _splitName(name):
    """
    Split a variable name into its static radical and its time-slice suffix.

    By convention, the name of a dynamic variable "X" in a dBN may be
      - "X0" for time slice 0, both in an unrolled BN and in a 2TBN
      - "Xt" for time slice t in a 2TBN
      - "X"+str(i) for time slice i (integer i>0) in an unrolled BN
      - any other name is not in a time slice

    :param name: str, name of the (possibly dynamic) variable
    :return: (static_name, timeslice) where timeslice is noTimeCluster, "t" or str(i)
    """
    if not name:
        # an empty name carries no suffix at all (the indexing below would raise IndexError)
        return name, noTimeCluster

    if name[-1] == "t":
        return name[:-1], "t"

    i = len(name) - 1
    if not name[i].isdigit():
        return name, noTimeCluster

    # walk backwards over the trailing digits
    while name[i].isdigit():
        if i == 0:
            # the whole name is made of digits: there is no radical, hence no time slice
            return name, noTimeCluster
        i -= 1

    return name[:i + 1], name[i + 1:]


def _isInFirstTimeSlice(name):
    """
    :return: True if there is a "0" at the end of name
    """
    return name.endswith("0")


def _isInSecondTimeSlice(name):
    """
    :return: True if there is a "t" at the end of name
    """
    return name.endswith("t")


def _isInNoTimeSlice(name):
    """
    :return: True if name ends neither with "0" nor with "t"
    """
    return not name.endswith(("0", "t"))


def realNameFrom2TBNname(name, ts):
    """
    Build the name of a 2TBN variable in a given time slice (no check).

    :param name: name of the variable in the 2TBN (ending with "0" or "t" when dynamic)
    :param ts: the time slice that replaces the last character of name
    :return: name with its last character replaced by ts, or name unchanged
             when it ends neither with "0" nor with "t"
    """
    if _isInNoTimeSlice(name):
        return name
    return f"{name[:-1]}{ts}"


def getTimeSlicesRange(dbn):
    """
    Group the variables of a dBN by time slice.

    :param dbn: a 2TBN or an unrolled BN
    :return: a dict mapping each time-slice key (as returned by _splitName) to the
             list of (name, radical) of the variables found in that time slice

    e.g. the keys are
      '0' and 't' for a classic 2TBN
      '0', '1', ..., str(T-1) for a classic unrolled BN
    (plus noTimeCluster if some variables are in no time slice)
    """
    timeslices = {}

    for node in dbn.nodes():
        fullname = dbn.variable(node).name()
        radical, ts = _splitName(fullname)
        timeslices.setdefault(ts, []).append((fullname, radical))

    return timeslices


def is2TBN(bn):
    """
    Check if bn is a 2 TimeSlice Bayesian network.

    The check is purely syntactical: every variable must be in time slice "0",
    in time slice "t" or in no time slice, both time slices must exist, and a
    variable present in both must have the same domain size in each.

    :param bn: the Bayesian network
    :return: a pair (ok, message): ok is True if the BN is syntactically correct
             to be a 2TBN, otherwise message explains why it is not
    """
    ts = getTimeSlicesRange(bn)
    if not set(ts) <= {noTimeCluster, "0", "t"}:
        return False, "Some variables are not correctly suffixed."

    # Same requirement as unroll2TBN; indexing a missing slice below would raise KeyError.
    if "0" not in ts or "t" not in ts:
        return False, "Time slices 0 and t must both be present."

    domainSizes = {radical: bn.variable(name).domainSize() for name, radical in ts["t"]}

    # collect every mismatch, not only the last one
    mismatches = [
        f"\n - for variables {name}/{radical}t"
        for name, radical in ts["0"]
        if radical in domainSizes and domainSizes[radical] != bn.variable(name).domainSize()
    ]

    if mismatches:
        return False, "Domain size mismatch : " + "".join(mismatches)

    return True, ""


def _timeSliceOrder(ts):
    """
    Sort key for time-slice keys: noTimeCluster first, then the numbered
    time slices in increasing order, then "t".
    """
    if ts == noTimeCluster:
        return -1
    if ts == "t":
        return 1e8
    return int(ts)


def _TimeSlicesToDot(dbn):
    """
    Try to correctly represent dBN and 2TBN in dot format.

    :param dbn: a 2TBN or an unrolled BN
    :return: a pydotplus graph with one cluster per time slice
    """
    timeslices = getTimeSlicesRange(dbn)
    ordered_slices = sorted(timeslices, key=_timeSliceOrder)

    # dynamic member makes pylint unhappy
    # pylint: disable=no-member
    g = dot.Dot(graph_type='digraph')
    g.set_rankdir("TD")
    g.set_splines("ortho")

    # one cluster per time slice; variables in no time slice go directly in the graph
    for k in ordered_slices:
        if k != noTimeCluster:
            cluster = dot.Cluster(k, label=f"Time slice {k}", bgcolor="#DDDDDD", rankdir="TD")
            g.add_subgraph(cluster)
        else:
            cluster = g  # small trick to add in graph variable in no timeslice
        for n, label in sorted(timeslices[k]):
            cluster.add_node(dot.Node('"' + n + '"',
                                      label='"' + label + '"',
                                      style='filled',
                                      color='#000000',
                                      fillcolor='white'))

    # the arcs of the BN
    for tail, head in dbn.arcs():
        g.add_edge(dot.Edge('"' + dbn.variable(tail).name() + '"',
                            '"' + dbn.variable(head).name() + '"',
                            constraint=False,
                            color="blue"))

    # invisible edges chaining the (sorted) variables of each time slice
    for k in ordered_slices:
        if k == noTimeCluster:
            continue
        names = [n for n, _ in sorted(timeslices[k])]
        for previous, current in zip(names, names[1:]):
            g.add_edge(dot.Edge('"' + previous + '"',
                                '"' + current + '"',
                                style="invis"))

    return g


def showTimeSlices(dbn, size=None):
    """
    Try to correctly display dBN and 2TBN.

    :param dbn: the dynamic BN
    :param size: size of the fig (default: gum.config["dynamicBN", "default_graph_size"])
    """
    # jupyter notebooks is optional
    # pylint: disable=import-outside-toplevel
    from pyAgrum.lib.notebook import showGraph
    if size is None:
        size = gum.config["dynamicBN", "default_graph_size"]

    showGraph(_TimeSlicesToDot(dbn), size)


def getTimeSlices(dbn, size=None):
    """
    Try to correctly represent dBN and 2TBN as an HTML string.

    :param dbn: the dynamic BN
    :param size: size of the fig (default: gum.config["dynamicBN", "default_graph_size"])
    :return: the value returned by pyAgrum.lib.notebook.getGraph for the dot graph
    """
    # jupyter notebooks is optional
    # pylint: disable=import-outside-toplevel
    from pyAgrum.lib.notebook import getGraph
    if size is None:
        size = gum.config["dynamicBN", "default_graph_size"]

    return getGraph(_TimeSlicesToDot(dbn), size)


def unroll2TBN(dbn, nbr):
    """
    Unroll a 2TBN given the number of time slices.

    :param dbn: the 2TBN
    :param nbr: the number of time slices (slices 0 to nbr-1 are created)
    :return: the unrolled BN
    :raise TypeError: if dbn is not a 2TBN (variables outside of slices "0"/"t"/none,
                      or slice "0" or "t" missing), or if a variable of time slice 0
                      has a parent in time slice t
    """
    slices = getTimeSlicesRange(dbn)
    # parentheses matter: `not A and B` would be parsed as `(not A) and B`
    if not ({noTimeCluster, "0", "t"}.issuperset(slices) and {"0", "t"}.issubset(slices)):
        raise TypeError("unroll2TBN needs a 2-TimeSlice BN")

    bn = gum.BayesNet()

    # variable creation
    for dbn_id in dbn.nodes():
        variable = dbn.variable(dbn_id)
        name = variable.name()
        if _isInSecondTimeSlice(name):
            for t in range(1, nbr):
                # create a clone of the variable in the new bn, then give it its true name
                nid = bn.add(variable)
                bn.changeVariableName(nid, realNameFrom2TBNname(name, t))
        else:
            # time slice 0 or no time slice: a single clone with the same name
            bn.add(variable)

    # add parents
    # the main pb : to have the same order for parents w.r.t the order in 2TBN
    for dbn_id in dbn.nodes():
        name = dbn.variable(dbn_id).name()
        child_in_t = _isInSecondTimeSlice(name)

        # right order for parents: drop the last name of the cpt (the remaining
        # ones are treated as the parents) and reverse
        parent_names = list(dbn.cpt(dbn_id).var_names)
        parent_names.pop()
        parent_names.reverse()

        for name_parent in parent_names:
            if child_in_t:
                for t in range(1, nbr):
                    if _isInFirstTimeSlice(name_parent):
                        new_name_parent = realNameFrom2TBNname(name_parent, t - 1)  # last TimeSlice
                    else:
                        new_name_parent = realNameFrom2TBNname(name_parent, t)  # current TimeSlice
                    new_name = realNameFrom2TBNname(name, t)  # necessary current TimeSlice
                    bn.addArc(bn.idFromName(new_name_parent), bn.idFromName(new_name))
            elif not _isInSecondTimeSlice(name_parent):
                bn.addArc(bn.idFromName(name_parent), bn.idFromName(name))
            else:
                if _isInFirstTimeSlice(name):
                    raise TypeError("An arc from timeslice t to timeslice 0 is impossible in dBN")
                # variable in no time slice with a parent in t: one arc per unrolled slice
                for t in range(1, nbr):
                    new_name_parent = realNameFrom2TBNname(name_parent, t)  # current TimeSlice
                    bn.addArc(bn.idFromName(new_name_parent), bn.idFromName(name))

    # potential creation
    for dbn_id in dbn.nodes():
        name = dbn.variable(dbn_id).name()
        if _isInSecondTimeSlice(name):
            for t in range(1, nbr):
                bn.cpt(bn.idFromName(realNameFrom2TBNname(name, t)))[:] = dbn.cpt(dbn_id)[:]
        else:
            bn.cpt(bn.idFromName(name))[:] = dbn.cpt(dbn_id)[:]

    return bn


def plotFollowUnrolled(lovars, dbn, T, evs):
    """
    Plot the dynamic evolution of a list of vars with a dBN.

    One stacked plot is drawn per followed variable: for each time step, the
    posterior of each of its values given the evidence.

    :param lovars: list of variables to follow (radicals, without time-slice suffix)
    :param dbn: the unrolled dbn
    :param T: the time range
    :param evs: observations
    """
    ie = gum.LazyPropagation(dbn)
    ie.setEvidence(evs)
    ie.makeInference()

    x = np.arange(T)

    for var in lovars:
        # the variable in time slice 0 gives the domain size and the labels
        v0 = dbn.variableFromName(var + "0")
        domain = range(v0.domainSize())

        # lpots[i][t] : posterior of the i-th value of var at time t
        lpots = [[ie.posterior(dbn.idFromName(var + str(t)))[i] for t in range(T)]
                 for i in domain]

        _, ax = plt.subplots()
        plt.ylim(top=1, bottom=0)
        ax.xaxis.grid()
        plt.title(f"Following variable {var}", fontsize=20)
        plt.xlabel('time')

        stack = ax.stackplot(x, lpots)

        # the legend is built from proxy rectangles carrying the face color of each layer
        proxy_rects = [Rectangle((0, 0), 1, 1, fc=pc.get_facecolor()[0]) for pc in stack]
        labels = [v0.label(i) for i in domain]
        plt.legend(proxy_rects, labels,
                   loc='center left',
                   bbox_to_anchor=(1, 0.5),
                   ncol=1,
                   fancybox=True,
                   shadow=True)

        plt.show()


def plotFollow(lovars, twoTdbn, T, evs):
    """
    Plot modifications of variables in a 2TBN knowing the size of the time window (T)
    and the evidence on the sequence.

    :param lovars: list of variables to follow
    :param twoTdbn: the two-timeslice dbn
    :param T: the time range
    :param evs: observations
    """
    plotFollowUnrolled(lovars, unroll2TBN(twoTdbn, T), T, evs)