1from functools import partial 2from warnings import warn 3 4import autograd.numpy as np 5import autograd.numpy.random as npr 6from autograd.scipy.misc import logsumexp 7from autograd.scipy.stats import dirichlet 8from autograd import hessian 9 10from ssm.util import one_hot, logistic, relu, rle, \ 11 fit_multiclass_logistic_regression, \ 12 fit_negative_binomial_integer_r, ensure_args_are_lists 13from ssm.stats import multivariate_normal_logpdf 14from ssm.optimizers import adam, bfgs, lbfgs, rmsprop, sgd 15 16 17class Transitions(object): 18 def __init__(self, K, D, M=0): 19 self.K, self.D, self.M = K, D, M 20 self.type_name = self.__class__.__name__ 21 22 @property 23 def params(self): 24 raise NotImplementedError 25 26 @params.setter 27 def params(self, value): 28 raise NotImplementedError 29 30 @ensure_args_are_lists 31 def initialize(self, datas, inputs=None, masks=None, tags=None): 32 pass 33 34 def permute(self, perm): 35 pass 36 37 def log_prior(self): 38 return 0 39 40 def log_transition_matrices(self, data, input, mask, tag): 41 raise NotImplementedError 42 43 def m_step(self, expectations, datas, inputs, masks, tags, 44 optimizer="lbfgs", num_iters=100, **kwargs): 45 """ 46 If M-step cannot be done in closed form for the transitions, default to BFGS. 47 """ 48 optimizer = dict(sgd=sgd, adam=adam, rmsprop=rmsprop, bfgs=bfgs, lbfgs=lbfgs)[optimizer] 49 50 # Maximize the expected log joint 51 def _expected_log_joint(expectations): 52 elbo = self.log_prior() 53 for data, input, mask, tag, (expected_states, expected_joints, _) \ 54 in zip(datas, inputs, masks, tags, expectations): 55 log_Ps = self.log_transition_matrices(data, input, mask, tag) 56 elbo += np.sum(expected_joints * log_Ps) 57 return elbo 58 59 # Normalize and negate for minimization 60 T = sum([data.shape[0] for data in datas]) 61 def _objective(params, itr): 62 self.params = params 63 obj = _expected_log_joint(expectations) 64 return -obj / T 65 66 # Call the optimizer. Persist state (e.g. SGD momentum) across calls to m_step. 67 optimizer_state = self.optimizer_state if hasattr(self, "optimizer_state") else None 68 self.params, self.optimizer_state = \ 69 optimizer(_objective, self.params, num_iters=num_iters, 70 state=optimizer_state, full_output=True, **kwargs) 71 72 def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): 73 # Return (T-1, D, D) array of blocks for the diagonal of the Hessian 74 warn("Analytical Hessian is not implemented for this transition class. \ 75 Optimization via Laplace-EM may be slow. Consider using an \ 76 alternative posterior and inference method.") 77 T, D = data.shape 78 obj = lambda x, E_zzp1: np.sum(E_zzp1 * self.log_transition_matrices(x, input, mask, tag)) 79 hess = hessian(obj) 80 terms = np.array([hess(x[None,:], Ezzp1) for x, Ezzp1 in zip(data, expected_joints)]) 81 return terms 82 83class StationaryTransitions(Transitions): 84 """ 85 Standard Hidden Markov Model with fixed initial distribution and transition matrix. 86 """ 87 def __init__(self, K, D, M=0): 88 super(StationaryTransitions, self).__init__(K, D, M=M) 89 Ps = .95 * np.eye(K) + .05 * npr.rand(K, K) 90 Ps /= Ps.sum(axis=1, keepdims=True) 91 self.log_Ps = np.log(Ps) 92 93 @property 94 def params(self): 95 return (self.log_Ps,) 96 97 @params.setter 98 def params(self, value): 99 self.log_Ps = value[0] 100 101 def permute(self, perm): 102 """ 103 Permute the discrete latent states. 104 """ 105 self.log_Ps = self.log_Ps[np.ix_(perm, perm)] 106 107 @property 108 def transition_matrix(self): 109 return np.exp(self.log_Ps - logsumexp(self.log_Ps, axis=1, keepdims=True)) 110 111 def log_transition_matrices(self, data, input, mask, tag): 112 T = data.shape[0] 113 log_Ps = self.log_Ps - logsumexp(self.log_Ps, axis=1, keepdims=True) 114 # return np.tile(log_Ps[None, :, :], (T-1, 1, 1)) 115 return log_Ps[None, :, :] 116 117 def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): 118 P = sum([np.sum(Ezzp1, axis=0) for _, Ezzp1, _ in expectations]) + 1e-16 119 P /= P.sum(axis=-1, keepdims=True) 120 self.log_Ps = np.log(P) 121 122 def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): 123 # Return (T-1, D, D) array of blocks for the diagonal of the Hessian 124 T, D = data.shape 125 return np.zeros((T-1, D, D)) 126 127class StickyTransitions(StationaryTransitions): 128 """ 129 Upweight the self transition prior. 130 131 pi_k ~ Dir(alpha + kappa * e_k) 132 """ 133 def __init__(self, K, D, M=0, alpha=1, kappa=100): 134 super(StickyTransitions, self).__init__(K, D, M=M) 135 self.alpha = alpha 136 self.kappa = kappa 137 138 def log_prior(self): 139 K = self.K 140 Ps = np.exp(self.log_Ps - logsumexp(self.log_Ps, axis=1, keepdims=True)) 141 142 lp = 0 143 for k in range(K): 144 alpha = self.alpha * np.ones(K) + self.kappa * (np.arange(K) == k) 145 lp += dirichlet.logpdf(Ps[k], alpha) 146 return lp 147 148 def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): 149 expected_joints = sum([np.sum(Ezzp1, axis=0) for _, Ezzp1, _ in expectations]) + 1e-8 150 expected_joints += self.kappa * np.eye(self.K) 151 P = expected_joints / expected_joints.sum(axis=1, keepdims=True) 152 self.log_Ps = np.log(P) 153 154 def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): 155 # Return (T-1, D, D) array of blocks for the diagonal of the Hessian 156 T, D = data.shape 157 return np.zeros((T-1, D, D)) 158 159class InputDrivenTransitions(StickyTransitions): 160 """ 161 Hidden Markov Model whose transition probabilities are 162 determined by a generalized linear model applied to the 163 exogenous input. 164 """ 165 def __init__(self, K, D, M, alpha=1, kappa=0, l2_penalty=0.0): 166 super(InputDrivenTransitions, self).__init__(K, D, M=M, alpha=alpha, kappa=kappa) 167 168 # Parameters linking input to state distribution 169 self.Ws = npr.randn(K, M) 170 171 # Regularization of Ws 172 self.l2_penalty = l2_penalty 173 174 @property 175 def params(self): 176 return self.log_Ps, self.Ws 177 178 @params.setter 179 def params(self, value): 180 self.log_Ps, self.Ws = value 181 182 def permute(self, perm): 183 """ 184 Permute the discrete latent states. 185 """ 186 self.log_Ps = self.log_Ps[np.ix_(perm, perm)] 187 self.Ws = self.Ws[perm] 188 189 def log_prior(self): 190 lp = super(InputDrivenTransitions, self).log_prior() 191 lp = lp + np.sum(-0.5 * self.l2_penalty * self.Ws**2) 192 return lp 193 194 def log_transition_matrices(self, data, input, mask, tag): 195 T = data.shape[0] 196 assert input.shape[0] == T 197 # Previous state effect 198 log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1)) 199 # Input effect 200 log_Ps = log_Ps + np.dot(input[1:], self.Ws.T)[:, None, :] 201 return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True) 202 203 def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): 204 Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs) 205 206 def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): 207 # Return (T-1, D, D) array of blocks for the diagonal of the Hessian 208 T, D = data.shape 209 return np.zeros((T-1, D, D)) 210 211class RecurrentTransitions(InputDrivenTransitions): 212 """ 213 Generalization of the input driven HMM in which the observations serve as future inputs 214 """ 215 def __init__(self, K, D, M=0, alpha=1, kappa=0): 216 super(RecurrentTransitions, self).__init__(K, D, M, alpha=alpha, kappa=kappa) 217 218 # Parameters linking past observations to state distribution 219 self.Rs = np.zeros((K, D)) 220 221 @property 222 def params(self): 223 return super(RecurrentTransitions, self).params + (self.Rs,) 224 225 @params.setter 226 def params(self, value): 227 self.Rs = value[-1] 228 super(RecurrentTransitions, self.__class__).params.fset(self, value[:-1]) 229 230 def permute(self, perm): 231 """ 232 Permute the discrete latent states. 233 """ 234 super(RecurrentTransitions, self).permute(perm) 235 self.Rs = self.Rs[perm] 236 237 def log_transition_matrices(self, data, input, mask, tag): 238 T, D = data.shape 239 # Previous state effect 240 log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1)) 241 # Input effect 242 log_Ps = log_Ps + np.dot(input[1:], self.Ws.T)[:, None, :] 243 # Past observations effect 244 log_Ps = log_Ps + np.dot(data[:-1], self.Rs.T)[:, None, :] 245 return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True) 246 247 def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): 248 Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs) 249 250 def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): 251 # Return (T-1, D, D) array of blocks for the diagonal of the Hessian 252 T, D = data.shape 253 hess = np.zeros((T-1,D,D)) 254 vtildes = np.exp(self.log_transition_matrices(data, input, mask, tag)) # normalized probabilities 255 Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1 256 for k in range(self.K): 257 vtilde = vtildes[:,k,:] # normalized probabilities given state k 258 Rv = vtilde@self.Rs 259 hess += Ez[:,k][:,None,None] * \ 260 ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs, self.Rs) \ 261 + np.einsum('ti, tj -> tij', Rv, Rv)) 262 return hess 263 264class RecurrentOnlyTransitions(Transitions): 265 """ 266 Only allow the past observations and inputs to influence the 267 next state. Get rid of the transition matrix and replace it 268 with a constant bias r. 269 """ 270 def __init__(self, K, D, M=0): 271 super(RecurrentOnlyTransitions, self).__init__(K, D, M) 272 273 # Parameters linking past observations to state distribution 274 self.Ws = npr.randn(K, M) 275 self.Rs = npr.randn(K, D) 276 self.r = npr.randn(K) 277 278 @property 279 def params(self): 280 return self.Ws, self.Rs, self.r 281 282 @params.setter 283 def params(self, value): 284 self.Ws, self.Rs, self.r = value 285 286 def permute(self, perm): 287 """ 288 Permute the discrete latent states. 289 """ 290 self.Ws = self.Ws[perm] 291 self.Rs = self.Rs[perm] 292 self.r = self.r[perm] 293 294 def log_transition_matrices(self, data, input, mask, tag): 295 T, D = data.shape 296 log_Ps = np.dot(input[1:], self.Ws.T)[:, None, :] # inputs 297 log_Ps = log_Ps + np.dot(data[:-1], self.Rs.T)[:, None, :] # past observations 298 log_Ps = log_Ps + self.r # bias 299 log_Ps = np.tile(log_Ps, (1, self.K, 1)) # expand 300 return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True) # normalize 301 302 def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): 303 Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs) 304 305 def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): 306 # Return (T-1, D, D) array of blocks for the diagonal of the Hessian 307 T, D = data.shape 308 v = np.dot(input[1:], self.Ws.T) + np.dot(data[:-1], self.Rs.T) + self.r 309 shifted_exp = np.exp(v - np.max(v,axis=1,keepdims=True)) 310 vtilde = shifted_exp / np.sum(shifted_exp,axis=1,keepdims=True) # normalized probabilities 311 Rv = vtilde@self.Rs 312 return np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs, self.Rs) \ 313 + np.einsum('ti, tj -> tij', Rv, Rv) 314 315class RBFRecurrentTransitions(InputDrivenTransitions): 316 """ 317 Recurrent transitions with radial basis functions for parameterizing 318 the next state probability given current continuous data. We have, 319 320 p(z_{t+1} = k | z_t, x_t) 321 \propto N(x_t | \mu_k, \Sigma_k) \times \pi_{z_t, z_{t+1}) 322 323 where {\mu_k, \Sigma_k, \pi_k}_{k=1}^K are learned parameters. 324 Equivalently, 325 326 log p(z_{t+1} = k | z_t, x_t) 327 = log N(x_t | \mu_k, \Sigma_k) + log \pi_{z_t, z_{t+1}) + const 328 = -D/2 log(2\pi) -1/2 log |Sigma_k| 329 -1/2 (x - \mu_k)^T \Sigma_k^{-1} (x-\mu_k) 330 + log \pi{z_t, z_{t+1}} 331 332 The difference between this and the recurrent model above is that the 333 log transition matrices are quadratic functions of x rather than linear. 334 335 While we're at it, there's no harm in adding a linear term to the log 336 transition matrices to capture input dependencies. 337 """ 338 def __init__(self, K, D, M=0, alpha=1, kappa=0): 339 super(RBFRecurrentTransitions, self).__init__(K, D, M=M, alpha=alpha, kappa=kappa) 340 341 # RBF parameters 342 self.mus = npr.randn(K, D) 343 self._sqrt_Sigmas = npr.randn(K, D, D) 344 345 @property 346 def params(self): 347 return self.log_Ps, self.mus, self._sqrt_Sigmas, self.Ws 348 349 @params.setter 350 def params(self, value): 351 self.log_Ps, self.mus, self._sqrt_Sigmas, self.Ws = value 352 353 @property 354 def Sigmas(self): 355 return np.matmul(self._sqrt_Sigmas, np.swapaxes(self._sqrt_Sigmas, -1, -2)) 356 357 @ensure_args_are_lists 358 def initialize(self, datas, inputs=None, masks=None, tags=None): 359 # Fit a GMM to the data to set the means and covariances 360 from sklearn.mixture import GaussianMixture 361 gmm = GaussianMixture(self.K, covariance_type="full") 362 gmm.fit(np.vstack(datas)) 363 self.mus = gmm.means_ 364 self._sqrt_Sigmas = np.linalg.cholesky(gmm.covariances_) 365 366 def permute(self, perm): 367 """ 368 Permute the discrete latent states. 369 """ 370 self.log_Ps = self.log_Ps[np.ix_(perm, perm)] 371 self.mus = self.mus[perm] 372 self.sqrt_Sigmas = self.sqrt_Sigmas[perm] 373 self.Ws = self.Ws[perm] 374 375 def log_transition_matrices(self, data, input, mask, tag): 376 assert np.all(mask), "Recurrent models require that all data are present." 377 378 T = data.shape[0] 379 assert input.shape[0] == T 380 K, D = self.K, self.D 381 382 # Previous state effect 383 log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1)) 384 385 # RBF recurrent function 386 rbf = multivariate_normal_logpdf(data[:-1, None, :], self.mus, self.Sigmas) 387 log_Ps = log_Ps + rbf[:, None, :] 388 389 # Input effect 390 log_Ps = log_Ps + np.dot(input[1:], self.Ws.T)[:, None, :] 391 return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True) 392 393 def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): 394 Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs) 395 396 397# Allow general nonlinear emission models with neural networks 398class NeuralNetworkRecurrentTransitions(Transitions): 399 def __init__(self, K, D, M=0, hidden_layer_sizes=(50,), nonlinearity="relu"): 400 super(NeuralNetworkRecurrentTransitions, self).__init__(K, D, M=M) 401 402 # Baseline transition probabilities 403 Ps = .95 * np.eye(K) + .05 * npr.rand(K, K) 404 Ps /= Ps.sum(axis=1, keepdims=True) 405 self.log_Ps = np.log(Ps) 406 407 # Initialize the NN weights 408 layer_sizes = (D + M,) + hidden_layer_sizes + (K,) 409 self.weights = [npr.randn(m, n) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])] 410 self.biases = [npr.randn(n) for n in layer_sizes[1:]] 411 412 nonlinearities = dict( 413 relu=relu, 414 tanh=np.tanh, 415 sigmoid=logistic) 416 self.nonlinearity = nonlinearities[nonlinearity] 417 418 @property 419 def params(self): 420 return self.log_Ps, self.weights, self.biases 421 422 @params.setter 423 def params(self, value): 424 self.log_Ps, self.weights, self.biases = value 425 426 def permute(self, perm): 427 self.log_Ps = self.log_Ps[np.ix_(perm, perm)] 428 self.weights[-1] = self.weights[-1][:,perm] 429 self.biases[-1] = self.biases[-1][perm] 430 431 def log_transition_matrices(self, data, input, mask, tag): 432 # Pass the data and inputs through the neural network 433 x = np.hstack((data[:-1], input[1:])) 434 for W, b in zip(self.weights, self.biases): 435 y = np.dot(x, W) + b 436 x = self.nonlinearity(y) 437 438 # Add the baseline transition biases 439 log_Ps = self.log_Ps[None, :, :] + y[:, None, :] 440 441 # Normalize 442 return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True) 443 444 def m_step(self, expectations, datas, inputs, masks, tags, optimizer="adam", num_iters=100, **kwargs): 445 # Default to adam instead of bfgs for the neural network model. 446 Transitions.m_step(self, expectations, datas, inputs, masks, tags, 447 optimizer=optimizer, num_iters=num_iters, **kwargs) 448 449 450class NegativeBinomialSemiMarkovTransitions(Transitions): 451 """ 452 Semi-Markov transition model with negative binomial (NB) distributed 453 state durations, as compared to the geometric state durations in the 454 standard Markov model. The negative binomial has higher variance than 455 the geometric, but its mode can be greater than 1. 456 457 The NB(r, p) distribution, with r a positive integer and p a probability 458 in [0, 1], is this distribution over number of heads before seeing 459 r tails where the probability of heads is p. The number of heads 460 between each tails is an independent geometric random variable. Thus, 461 the total number of heads is the sum of r independent and identically 462 distributed geometric random variables. 463 464 We can "embed" the semi-Markov model with negative binomial durations 465 in the standard Markov model by expanding the state space. Map each 466 discrete state k to r new states: (k,1), (k,2), ..., (k,r_k), 467 for k in 1, ..., K. The total number of states is \sum_k r_k, 468 where state k has a NB(r_k, p_k) duration distribution. 469 470 The transition probabilities are as follows. The probability of staying 471 within the same "super state" are: 472 473 p(z_{t+1} = (k,i) | z_t = (k,i)) = p_k 474 475 and for 0 <= j <= r_k - i 476 477 p(z_{t+1} = (k,i+j) | z_t = (k,i)) = (1-p_k)^{j-i} p_k 478 479 The probability of flipping (r_k - i + 1) tails in a row in state k; 480 i.e. the probability of exiting super state k, is (1-p_k)^{r_k-i+1}. 481 Thus, the probability of transitioning to a new super state is: 482 483 p(z_{t+1} = (j,1) | z_t = (k,i)) = (1-p_k)^{r_k-i+1} * P[k, j] 484 485 where P[k, j] is a transition matrix with zero diagonal. 486 487 As a sanity check, note that the sum of probabilities is indeed 1: 488 489 \sum_{j=i}^{r_k} p(z_{t+1} = (k,j) | z_t = (k,i)) 490 + \sum_{m \neq k} p(z_{t+1} = (m, 1) | z_t = (k, i)) 491 492 = \sum_{j=0}^{r_k-i} (1-p_k)^j p_k + \sum_{m \neq k} (1-p_k)^{r_k-i+1} * P[k, j] 493 494 = p_k (1-(1-p_k)^{r_k-i+1}) / (1-(1-p_k)) + (1-p_k)^{r_k-i+1} 495 496 = 1 - (1-p_k)^{r_k-i+1} + (1 - p_k)^{r_k-i+1} 497 498 = 1. 499 500 where we used the geometric series and the fact that \sum_{j != k} P[k, j] = 1. 501 """ 502 def __init__(self, K, D, M=0, r_min=1, r_max=20): 503 assert K > 1, "Explicit duration models only work if num states > 1." 504 super(NegativeBinomialSemiMarkovTransitions, self).__init__(K, D, M=M) 505 506 # Initialize the super state transition probabilities 507 self.Ps = npr.rand(K, K) 508 np.fill_diagonal(self.Ps, 0) 509 self.Ps /= self.Ps.sum(axis=1, keepdims=True) 510 511 # Initialize the negative binomial duration probabilities 512 self.r_min, self.r_max = r_min, r_max 513 self.rs = npr.randint(r_min, r_max + 1, size=K) 514 # self.rs = np.ones(K, dtype=int) 515 # self.ps = npr.rand(K) 516 self.ps = 0.5 * np.ones(K) 517 518 # Initialize the transition matrix 519 self._transition_matrix = None 520 521 @property 522 def params(self): 523 return (self.Ps, self.rs, self.ps) 524 525 @params.setter 526 def params(self, value): 527 Ps, rs, ps = value 528 assert Ps.shape == (self.K, self.K) 529 assert np.allclose(np.diag(Ps), 0) 530 assert np.allclose(Ps.sum(1), 1) 531 assert rs.shape == (self.K) 532 assert rs.dtype == int 533 assert np.all(rs > 0) 534 assert ps.shape == (self.K) 535 assert np.all(ps > 0) 536 assert np.all(ps < 1) 537 self.Ps, self.rs, self.ps = Ps, rs, ps 538 539 # Reset the transition matrix 540 self._transition_matrix = None 541 542 def permute(self, perm): 543 """ 544 Permute the discrete latent states. 545 """ 546 self.Ps = self.Ps[np.ix_(perm, perm)] 547 self.rs = self.rs[perm] 548 self.ps = self.ps[perm] 549 550 # Reset the transition matrix 551 self._transition_matrix = None 552 553 @property 554 def total_num_states(self): 555 return np.sum(self.rs) 556 557 @property 558 def state_map(self): 559 return np.repeat(np.arange(self.K), self.rs) 560 561 @property 562 def transition_matrix(self): 563 if self._transition_matrix is not None: 564 return self._transition_matrix 565 566 As, rs, ps = self.Ps, self.rs, self.ps 567 568 # Fill in the transition matrix one block at a time 569 K_total = self.total_num_states 570 P = np.zeros((K_total, K_total)) 571 starts = np.concatenate(([0], np.cumsum(rs)[:-1])) 572 ends = np.cumsum(rs) 573 for (i, j), Aij in np.ndenumerate(As): 574 block = P[starts[i]:ends[i], starts[j]:ends[j]] 575 576 # Diagonal blocks (stay in sub-state or advance to next sub-state) 577 if i == j: 578 for k in range(rs[i]): 579 # p(z_{t+1} = (.,i+k) | z_t = (.,i)) = (1-p)^k p 580 # for 0 <= k <= r - i 581 block += (1 - ps[i])**k * ps[i] * np.diag(np.ones(rs[i]-k), k=k) 582 583 # Off-diagonal blocks (exit to a new super state) 584 else: 585 # p(z_{t+1} = (j,1) | z_t = (k,i)) = (1-p_k)^{r_k-i+1} * A[k, j] 586 block[:,0] = (1-ps[i]) ** np.arange(rs[i], 0, -1) * Aij 587 588 assert np.allclose(P.sum(1),1) 589 assert (0 <= P).all() and (P <= 1.).all() 590 591 # Cache the transition matrix 592 self._transition_matrix = P 593 594 return P 595 596 def log_transition_matrices(self, data, input, mask, tag): 597 T = data.shape[0] 598 P = self.transition_matrix 599 return np.tile(np.log(P)[None, :, :], (T-1, 1, 1)) 600 601 def m_step(self, expectations, datas, inputs, masks, tags, samples, **kwargs): 602 # Update the transition matrix between super states 603 P = sum([np.sum(Ezzp1, axis=0) for _, Ezzp1, _ in expectations]) + 1e-16 604 np.fill_diagonal(P, 0) 605 P /= P.sum(axis=-1, keepdims=True) 606 self.Ps = P 607 608 # Fit negative binomial models for each duration based on sampled states 609 states, durations = map(np.concatenate, zip(*[rle(z_smpl) for z_smpl in samples])) 610 for k in range(self.K): 611 self.rs[k], self.ps[k] = \ 612 fit_negative_binomial_integer_r(durations[states == k], self.r_min, self.r_max) 613 614 # Reset the transition matrix 615 self._transition_matrix = None 616