from functools import partial
from warnings import warn

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.scipy.special import logsumexp
from autograd.scipy.stats import dirichlet
from autograd import hessian

from ssm.util import one_hot, logistic, relu, rle, \
    fit_multiclass_logistic_regression, \
    fit_negative_binomial_integer_r, ensure_args_are_lists
from ssm.stats import multivariate_normal_logpdf
from ssm.optimizers import adam, bfgs, lbfgs, rmsprop, sgd


class Transitions(object):
    def __init__(self, K, D, M=0):
        self.K, self.D, self.M = K, D, M
        self.type_name = self.__class__.__name__

    @property
    def params(self):
        raise NotImplementedError

    @params.setter
    def params(self, value):
        raise NotImplementedError

    @ensure_args_are_lists
    def initialize(self, datas, inputs=None, masks=None, tags=None):
        pass

    def permute(self, perm):
        pass

    def log_prior(self):
        return 0

    def log_transition_matrices(self, data, input, mask, tag):
        raise NotImplementedError

    def m_step(self, expectations, datas, inputs, masks, tags,
               optimizer="lbfgs", num_iters=100, **kwargs):
        """
        If the M-step cannot be done in closed form for the transitions, default to L-BFGS.
        """
        optimizer = dict(sgd=sgd, adam=adam, rmsprop=rmsprop, bfgs=bfgs, lbfgs=lbfgs)[optimizer]

        # Maximize the expected log joint
        def _expected_log_joint(expectations):
            elbo = self.log_prior()
            for data, input, mask, tag, (expected_states, expected_joints, _) \
                in zip(datas, inputs, masks, tags, expectations):
                log_Ps = self.log_transition_matrices(data, input, mask, tag)
                elbo += np.sum(expected_joints * log_Ps)
            return elbo

        # Normalize and negate for minimization
        T = sum([data.shape[0] for data in datas])
        def _objective(params, itr):
            self.params = params
            obj = _expected_log_joint(expectations)
            return -obj / T

        # Call the optimizer. Persist state (e.g. SGD momentum) across calls to m_step.
        optimizer_state = self.optimizer_state if hasattr(self, "optimizer_state") else None
        self.params, self.optimizer_state = \
            optimizer(_objective, self.params, num_iters=num_iters,
                      state=optimizer_state, full_output=True, **kwargs)
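    # Example (illustrative): a subclass with no closed-form update can call this generic
    # routine directly, e.g. transitions.m_step(expectations, datas, inputs, masks, tags,
    # optimizer="adam", num_iters=50).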

    def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints):
        # Return (T-1, D, D) array of blocks for the diagonal of the Hessian
        warn("Analytical Hessian is not implemented for this transition class. \
              Optimization via Laplace-EM may be slow. Consider using an \
              alternative posterior and inference method.")
        T, D = data.shape
        obj = lambda x, E_zzp1: np.sum(E_zzp1 * self.log_transition_matrices(x, input, mask, tag))
        hess = hessian(obj)
        terms = np.array([hess(x[None,:], Ezzp1) for x, Ezzp1 in zip(data, expected_joints)])
        return terms

class StationaryTransitions(Transitions):
    """
    Standard Hidden Markov Model with fixed initial distribution and transition matrix.
    """
    def __init__(self, K, D, M=0):
        super(StationaryTransitions, self).__init__(K, D, M=M)
        Ps = .95 * np.eye(K) + .05 * npr.rand(K, K)
        Ps /= Ps.sum(axis=1, keepdims=True)
        self.log_Ps = np.log(Ps)

    @property
    def params(self):
        return (self.log_Ps,)

    @params.setter
    def params(self, value):
        self.log_Ps = value[0]

    def permute(self, perm):
        """
        Permute the discrete latent states.
        """
        self.log_Ps = self.log_Ps[np.ix_(perm, perm)]

    @property
    def transition_matrix(self):
        return np.exp(self.log_Ps - logsumexp(self.log_Ps, axis=1, keepdims=True))

    def log_transition_matrices(self, data, input, mask, tag):
        T = data.shape[0]
        log_Ps = self.log_Ps - logsumexp(self.log_Ps, axis=1, keepdims=True)
        # Return a (1, K, K) array; it broadcasts over the T-1 time steps downstream.
        return log_Ps[None, :, :]

    def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        P = sum([np.sum(Ezzp1, axis=0) for _, Ezzp1, _ in expectations]) + 1e-16
        P /= P.sum(axis=-1, keepdims=True)
        self.log_Ps = np.log(P)
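    # Closed-form M-step: the new transition matrix is the row-normalized sum of expected
    # joint state counts, P[i, j] proportional to sum_t E[z_t = i, z_{t+1} = j], with a small
    # constant added for numerical stability.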

    def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints):
        # Return (T-1, D, D) array of blocks for the diagonal of the Hessian
        T, D = data.shape
        return np.zeros((T-1, D, D))

class StickyTransitions(StationaryTransitions):
    """
    Upweight the self transition prior.

    pi_k ~ Dir(alpha + kappa * e_k)
    """
    def __init__(self, K, D, M=0, alpha=1, kappa=100):
        super(StickyTransitions, self).__init__(K, D, M=M)
        self.alpha = alpha
        self.kappa = kappa

    def log_prior(self):
        K = self.K
        Ps = np.exp(self.log_Ps - logsumexp(self.log_Ps, axis=1, keepdims=True))

        lp = 0
        for k in range(K):
            alpha = self.alpha * np.ones(K) + self.kappa * (np.arange(K) == k)
            lp += dirichlet.logpdf(Ps[k], alpha)
        return lp

    def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        expected_joints = sum([np.sum(Ezzp1, axis=0) for _, Ezzp1, _ in expectations]) + 1e-8
        expected_joints += self.kappa * np.eye(self.K)
        P = expected_joints / expected_joints.sum(axis=1, keepdims=True)
        self.log_Ps = np.log(P)
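    # Note: with the default alpha = 1 the Dirichlet prior contributes alpha - 1 = 0
    # pseudo-counts, so adding kappa to the diagonal of the expected joint counts gives
    # the MAP estimate under the sticky prior above.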

    def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints):
        # Return (T-1, D, D) array of blocks for the diagonal of the Hessian
        T, D = data.shape
        return np.zeros((T-1, D, D))

class InputDrivenTransitions(StickyTransitions):
    """
    Hidden Markov Model whose transition probabilities are
    determined by a generalized linear model applied to the
    exogenous input.
    """
    def __init__(self, K, D, M, alpha=1, kappa=0, l2_penalty=0.0):
        super(InputDrivenTransitions, self).__init__(K, D, M=M, alpha=alpha, kappa=kappa)

        # Parameters linking input to state distribution
        self.Ws = npr.randn(K, M)

        # Regularization of Ws
        self.l2_penalty = l2_penalty

    @property
    def params(self):
        return self.log_Ps, self.Ws

    @params.setter
    def params(self, value):
        self.log_Ps, self.Ws = value

    def permute(self, perm):
        """
        Permute the discrete latent states.
        """
        self.log_Ps = self.log_Ps[np.ix_(perm, perm)]
        self.Ws = self.Ws[perm]

    def log_prior(self):
        lp = super(InputDrivenTransitions, self).log_prior()
        lp = lp + np.sum(-0.5 * self.l2_penalty * self.Ws**2)
        return lp

    def log_transition_matrices(self, data, input, mask, tag):
        T = data.shape[0]
        assert input.shape[0] == T
        # Previous state effect
        log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1))
        # Input effect
        log_Ps = log_Ps + np.dot(input[1:], self.Ws.T)[:, None, :]
        return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True)
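    # Note: log_transition_matrices above returns a (T-1, K, K) array; entry [t, i, j] is
    # log Pr(z_{t+1} = j | z_t = i) given the input at time t+1.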

    def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs)

    def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints):
        # Return (T-1, D, D) array of blocks for the diagonal of the Hessian
        T, D = data.shape
        return np.zeros((T-1, D, D))

class RecurrentTransitions(InputDrivenTransitions):
    """
    Generalization of the input-driven HMM in which the preceding observations
    also serve as inputs to the transition model.
    """
    def __init__(self, K, D, M=0, alpha=1, kappa=0):
        super(RecurrentTransitions, self).__init__(K, D, M, alpha=alpha, kappa=kappa)

        # Parameters linking past observations to state distribution
        self.Rs = np.zeros((K, D))

    @property
    def params(self):
        return super(RecurrentTransitions, self).params + (self.Rs,)

    @params.setter
    def params(self, value):
        self.Rs = value[-1]
        super(RecurrentTransitions, self.__class__).params.fset(self, value[:-1])

    def permute(self, perm):
        """
        Permute the discrete latent states.
        """
        super(RecurrentTransitions, self).permute(perm)
        self.Rs = self.Rs[perm]

    def log_transition_matrices(self, data, input, mask, tag):
        T, D = data.shape
        # Previous state effect
        log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1))
        # Input effect
        log_Ps = log_Ps + np.dot(input[1:], self.Ws.T)[:, None, :]
        # Past observations effect
        log_Ps = log_Ps + np.dot(data[:-1], self.Rs.T)[:, None, :]
        return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True)

    def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs)

    def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints):
        # Return (T-1, D, D) array of blocks for the diagonal of the Hessian
        T, D = data.shape
        hess = np.zeros((T-1,D,D))
        vtildes = np.exp(self.log_transition_matrices(data, input, mask, tag)) # normalized probabilities
        Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
        for k in range(self.K):
            vtilde = vtildes[:,k,:] # normalized probabilities given state k
            Rv = vtilde@self.Rs
            hess += Ez[:,k][:,None,None] * \
                    ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs, self.Rs) \
                    + np.einsum('ti, tj -> tij', Rv, Rv))
        return hess
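    # The only dependence on x_t is through R x_t, so each Hessian block above is
    # sum_k E[z_t = k] * (R^T vtilde_k vtilde_k^T R - R^T diag(vtilde_k) R), i.e. minus the
    # Hessian of row k's log-normalizer, weighted by the marginal state probabilities.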

class RecurrentOnlyTransitions(Transitions):
    """
    Only allow the past observations and inputs to influence the
    next state.  Get rid of the transition matrix and replace it
    with a constant bias r.
    """
    def __init__(self, K, D, M=0):
        super(RecurrentOnlyTransitions, self).__init__(K, D, M)

        # Parameters linking past observations to state distribution
        self.Ws = npr.randn(K, M)
        self.Rs = npr.randn(K, D)
        self.r = npr.randn(K)

    @property
    def params(self):
        return self.Ws, self.Rs, self.r

    @params.setter
    def params(self, value):
        self.Ws, self.Rs, self.r = value

    def permute(self, perm):
        """
        Permute the discrete latent states.
        """
        self.Ws = self.Ws[perm]
        self.Rs = self.Rs[perm]
        self.r = self.r[perm]

    def log_transition_matrices(self, data, input, mask, tag):
        T, D = data.shape
        log_Ps = np.dot(input[1:], self.Ws.T)[:, None, :]              # inputs
        log_Ps = log_Ps + np.dot(data[:-1], self.Rs.T)[:, None, :]     # past observations
        log_Ps = log_Ps + self.r                                       # bias
        log_Ps = np.tile(log_Ps, (1, self.K, 1))                       # expand
        return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True)       # normalize

    def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs)

    def hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints):
        # Return (T-1, D, D) array of blocks for the diagonal of the Hessian
        T, D = data.shape
        v = np.dot(input[1:], self.Ws.T) + np.dot(data[:-1], self.Rs.T) + self.r
        shifted_exp = np.exp(v - np.max(v,axis=1,keepdims=True))
        vtilde = shifted_exp / np.sum(shifted_exp,axis=1,keepdims=True) # normalized probabilities
        Rv = vtilde@self.Rs
        return np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs, self.Rs) \
               + np.einsum('ti, tj -> tij', Rv, Rv)
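    # Since every row of the transition matrix is identical here, the expected joint
    # probabilities sum to one at each time step, and each Hessian block reduces to minus
    # the Hessian of the log-normalizer: R^T vtilde vtilde^T R - R^T diag(vtilde) R.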

class RBFRecurrentTransitions(InputDrivenTransitions):
    """
    Recurrent transitions with radial basis functions for parameterizing
    the next state probability given current continuous data. We have,

    p(z_{t+1} = k | z_t, x_t)
        \propto N(x_t | \mu_k, \Sigma_k) \times \pi_{z_t, z_{t+1}}

    where {\mu_k, \Sigma_k, \pi_k}_{k=1}^K are learned parameters.
    Equivalently,

    log p(z_{t+1} = k | z_t, x_t)
        = log N(x_t | \mu_k, \Sigma_k) + log \pi_{z_t, z_{t+1}} + const
        = -D/2 log(2\pi) - 1/2 log |\Sigma_k|
          - 1/2 (x - \mu_k)^T \Sigma_k^{-1} (x - \mu_k)
          + log \pi_{z_t, z_{t+1}}

    The difference between this and the recurrent model above is that the
    log transition matrices are quadratic functions of x rather than linear.

    While we're at it, there's no harm in adding a linear term to the log
    transition matrices to capture input dependencies.
    """
    def __init__(self, K, D, M=0, alpha=1, kappa=0):
        super(RBFRecurrentTransitions, self).__init__(K, D, M=M, alpha=alpha, kappa=kappa)

        # RBF parameters
        self.mus = npr.randn(K, D)
        self._sqrt_Sigmas = npr.randn(K, D, D)

    @property
    def params(self):
        return self.log_Ps, self.mus, self._sqrt_Sigmas, self.Ws

    @params.setter
    def params(self, value):
        self.log_Ps, self.mus, self._sqrt_Sigmas, self.Ws = value

    @property
    def Sigmas(self):
        return np.matmul(self._sqrt_Sigmas, np.swapaxes(self._sqrt_Sigmas, -1, -2))
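    # Parameterizing Sigma_k = L_k L_k^T through the unconstrained factor _sqrt_Sigmas
    # keeps the covariances positive semidefinite throughout optimization.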

    @ensure_args_are_lists
    def initialize(self, datas, inputs=None, masks=None, tags=None):
        # Fit a GMM to the data to set the means and covariances
        from sklearn.mixture import GaussianMixture
        gmm = GaussianMixture(self.K, covariance_type="full")
        gmm.fit(np.vstack(datas))
        self.mus = gmm.means_
        self._sqrt_Sigmas = np.linalg.cholesky(gmm.covariances_)

    def permute(self, perm):
        """
        Permute the discrete latent states.
        """
        self.log_Ps = self.log_Ps[np.ix_(perm, perm)]
        self.mus = self.mus[perm]
        self._sqrt_Sigmas = self._sqrt_Sigmas[perm]
        self.Ws = self.Ws[perm]

    def log_transition_matrices(self, data, input, mask, tag):
        assert np.all(mask), "Recurrent models require that all data are present."

        T = data.shape[0]
        assert input.shape[0] == T
        K, D = self.K, self.D

        # Previous state effect
        log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1))

        # RBF recurrent function
        rbf = multivariate_normal_logpdf(data[:-1, None, :], self.mus, self.Sigmas)
        log_Ps = log_Ps + rbf[:, None, :]

        # Input effect
        log_Ps = log_Ps + np.dot(input[1:], self.Ws.T)[:, None, :]
        return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True)

    def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        Transitions.m_step(self, expectations, datas, inputs, masks, tags, **kwargs)


# Allow general nonlinear transition models with neural networks
class NeuralNetworkRecurrentTransitions(Transitions):
    def __init__(self, K, D, M=0, hidden_layer_sizes=(50,), nonlinearity="relu"):
        super(NeuralNetworkRecurrentTransitions, self).__init__(K, D, M=M)

        # Baseline transition probabilities
        Ps = .95 * np.eye(K) + .05 * npr.rand(K, K)
        Ps /= Ps.sum(axis=1, keepdims=True)
        self.log_Ps = np.log(Ps)

        # Initialize the NN weights
        layer_sizes = (D + M,) + hidden_layer_sizes + (K,)
        self.weights = [npr.randn(m, n) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
        self.biases = [npr.randn(n) for n in layer_sizes[1:]]

        nonlinearities = dict(
            relu=relu,
            tanh=np.tanh,
            sigmoid=logistic)
        self.nonlinearity = nonlinearities[nonlinearity]

    @property
    def params(self):
        return self.log_Ps, self.weights, self.biases

    @params.setter
    def params(self, value):
        self.log_Ps, self.weights, self.biases = value

    def permute(self, perm):
        self.log_Ps = self.log_Ps[np.ix_(perm, perm)]
        self.weights[-1] = self.weights[-1][:,perm]
        self.biases[-1] = self.biases[-1][perm]

    def log_transition_matrices(self, data, input, mask, tag):
        # Pass the data and inputs through the neural network
        x = np.hstack((data[:-1], input[1:]))
        for W, b in zip(self.weights, self.biases):
            y = np.dot(x, W) + b
            x = self.nonlinearity(y)
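        # The loop leaves `y` as the final layer's pre-activation; these raw outputs serve
        # as the logits below, so the nonlinearity applied on the last iteration is unused.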

        # Add the baseline transition biases
        log_Ps = self.log_Ps[None, :, :] + y[:, None, :]

        # Normalize
        return log_Ps - logsumexp(log_Ps, axis=2, keepdims=True)

    def m_step(self, expectations, datas, inputs, masks, tags, optimizer="adam", num_iters=100, **kwargs):
        # Default to Adam instead of L-BFGS for the neural network model.
        Transitions.m_step(self, expectations, datas, inputs, masks, tags,
            optimizer=optimizer, num_iters=num_iters, **kwargs)


class NegativeBinomialSemiMarkovTransitions(Transitions):
    """
    Semi-Markov transition model with negative binomial (NB) distributed
    state durations, as compared to the geometric state durations in the
    standard Markov model.  The negative binomial has higher variance than
    the geometric, but its mode can be greater than 1.

    The NB(r, p) distribution, with r a positive integer and p a probability
    in [0, 1], is the distribution over the number of heads seen before the
    r-th tail, where each flip comes up heads with probability p.  The number
    of heads between consecutive tails is an independent geometric random
    variable, so the total number of heads is the sum of r independent and
    identically distributed geometric random variables.

    We can "embed" the semi-Markov model with negative binomial durations
    in the standard Markov model by expanding the state space.  Map each
    discrete state k to r_k new states: (k,1), (k,2), ..., (k,r_k),
    for k in 1, ..., K. The total number of states is \sum_k r_k,
    where state k has a NB(r_k, p_k) duration distribution.

    The transition probabilities are as follows. The probabilities of staying
    within the same "super state" are:

    p(z_{t+1} = (k,i) | z_t = (k,i)) = p_k

    and, for 0 <= j <= r_k - i,

    p(z_{t+1} = (k,i+j) | z_t = (k,i)) = (1-p_k)^j p_k

    The probability of flipping (r_k - i + 1) tails in a row in state k,
    i.e. the probability of exiting super state k, is (1-p_k)^{r_k-i+1}.
    Thus, the probability of transitioning to a new super state is:

    p(z_{t+1} = (j,1) | z_t = (k,i)) = (1-p_k)^{r_k-i+1} * P[k, j]

    where P[k, j] is a transition matrix with zero diagonal.

    As a sanity check, note that the sum of probabilities is indeed 1:

    \sum_{j=i}^{r_k} p(z_{t+1} = (k,j) | z_t = (k,i))
        + \sum_{m \neq k}  p(z_{t+1} = (m, 1) | z_t = (k, i))

    = \sum_{j=0}^{r_k-i} (1-p_k)^j p_k + \sum_{m \neq k} (1-p_k)^{r_k-i+1} * P[k, m]

    = p_k (1-(1-p_k)^{r_k-i+1}) / (1-(1-p_k)) + (1-p_k)^{r_k-i+1}

    = 1 - (1-p_k)^{r_k-i+1} + (1 - p_k)^{r_k-i+1}

    = 1,

    where we used the geometric series and the fact that \sum_{m \neq k} P[k, m] = 1.
    """
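    # Worked example (illustrative): with K = 2, rs = [2, 1], and ps = [p1, p2], the expanded
    # state space is {(1,1), (1,2), (2,1)} and the embedded transition matrix is
    #
    #              (1,1)         (1,2)         (2,1)
    #   (1,1)      p1            (1-p1) p1     (1-p1)^2
    #   (1,2)      0             p1            1-p1
    #   (2,1)      1-p2          0             p2
    #
    # Each row sums to one, matching the sanity check above.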
    def __init__(self, K, D, M=0, r_min=1, r_max=20):
        assert K > 1, "Explicit duration models only work if num states > 1."
        super(NegativeBinomialSemiMarkovTransitions, self).__init__(K, D, M=M)

        # Initialize the super state transition probabilities
        self.Ps = npr.rand(K, K)
        np.fill_diagonal(self.Ps, 0)
        self.Ps /= self.Ps.sum(axis=1, keepdims=True)

        # Initialize the negative binomial duration probabilities
        self.r_min, self.r_max = r_min, r_max
        self.rs = npr.randint(r_min, r_max + 1, size=K)
        # self.rs = np.ones(K, dtype=int)
        # self.ps = npr.rand(K)
        self.ps = 0.5 * np.ones(K)

        # Initialize the transition matrix
        self._transition_matrix = None

    @property
    def params(self):
        return (self.Ps, self.rs, self.ps)

    @params.setter
    def params(self, value):
        Ps, rs, ps = value
        assert Ps.shape == (self.K, self.K)
        assert np.allclose(np.diag(Ps), 0)
        assert np.allclose(Ps.sum(1), 1)
        assert rs.shape == (self.K,)
        assert rs.dtype == int
        assert np.all(rs > 0)
        assert ps.shape == (self.K,)
        assert np.all(ps > 0)
        assert np.all(ps < 1)
        self.Ps, self.rs, self.ps = Ps, rs, ps

        # Reset the transition matrix
        self._transition_matrix = None

    def permute(self, perm):
        """
        Permute the discrete latent states.
        """
        self.Ps = self.Ps[np.ix_(perm, perm)]
        self.rs = self.rs[perm]
        self.ps = self.ps[perm]

        # Reset the transition matrix
        self._transition_matrix = None

    @property
    def total_num_states(self):
        return np.sum(self.rs)

    @property
    def state_map(self):
        return np.repeat(np.arange(self.K), self.rs)

    @property
    def transition_matrix(self):
        if self._transition_matrix is not None:
            return self._transition_matrix

        As, rs, ps = self.Ps, self.rs, self.ps

        # Fill in the transition matrix one block at a time
        K_total = self.total_num_states
        P = np.zeros((K_total, K_total))
        starts = np.concatenate(([0], np.cumsum(rs)[:-1]))
        ends = np.cumsum(rs)
        for (i, j), Aij in np.ndenumerate(As):
            block = P[starts[i]:ends[i], starts[j]:ends[j]]
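            # `block` is a view into P, so the in-place updates below write directly into P.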

            # Diagonal blocks (stay in sub-state or advance to next sub-state)
            if i == j:
                for k in range(rs[i]):
                    # p(z_{t+1} = (.,i+k) | z_t = (.,i)) = (1-p)^k p
                    # for 0 <= k <= r - i
                    block += (1 - ps[i])**k * ps[i] * np.diag(np.ones(rs[i]-k), k=k)

            # Off-diagonal blocks (exit to a new super state)
            else:
                # p(z_{t+1} = (j,1) | z_t = (k,i)) = (1-p_k)^{r_k-i+1} * A[k, j]
                block[:,0] = (1-ps[i]) ** np.arange(rs[i], 0, -1) * Aij

        assert np.allclose(P.sum(1),1)
        assert (0 <= P).all() and (P <= 1.).all()

        # Cache the transition matrix
        self._transition_matrix = P

        return P

    def log_transition_matrices(self, data, input, mask, tag):
        T = data.shape[0]
        P = self.transition_matrix
        return np.tile(np.log(P)[None, :, :], (T-1, 1, 1))

    def m_step(self, expectations, datas, inputs, masks, tags, samples, **kwargs):
        # Update the transition matrix between super states
        P = sum([np.sum(Ezzp1, axis=0) for _, Ezzp1, _ in expectations]) + 1e-16
        np.fill_diagonal(P, 0)
        P /= P.sum(axis=-1, keepdims=True)
        self.Ps = P

        # Fit negative binomial models for each duration based on sampled states
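        # rle run-length encodes each sampled state sequence into (state, duration) pairs,
        # so `durations[states == k]` collects the observed dwell times in state k.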
        states, durations = map(np.concatenate, zip(*[rle(z_smpl) for z_smpl in samples]))
        for k in range(self.K):
            self.rs[k], self.ps[k] = \
                fit_negative_binomial_integer_r(durations[states == k], self.r_min, self.r_max)

        # Reset the transition matrix
        self._transition_matrix = None