1# distutils: extra_compile_args = -O3
2# cython: wraparound=False
3# cython: boundscheck=False
4# cython: nonecheck=False
5# cython: cdivision=True
6
7import numpy as np
8cimport numpy as np
9
10from libc.math cimport log, exp, fmax, INFINITY
11
cdef double logsumexp(double[::1] x) nogil:
    """Return log(sum(exp(x))) computed stably by shifting by max(x).

    Infinite maxima are returned directly: if every entry is -inf (or x is
    empty) the result is -inf, and if any entry is +inf the result is +inf.
    Without this guard, x[i] - m would be inf - inf = NaN, and a single
    unreachable state would turn the result (and everything computed from
    it) into NaN.
    """
    cdef Py_ssize_t i
    cdef Py_ssize_t N = x.shape[0]
    cdef double m = -INFINITY
    cdef double out = 0.0

    # find the max (fmax skips NaN operands)
    for i in range(N):
        m = fmax(m, x[i])

    # exp-sum would be exactly 0 (m == -inf) or +inf (m == +inf); avoid inf - inf
    if m == INFINITY or m == -INFINITY:
        return m

    # sum the shifted exponentials; the max term contributes exactly 1
    for i in range(N):
        out += exp(x[i] - m)

    return m + log(out)
29
30
cdef dlse(double[::1] a,
          double[::1] out):
    """Write the gradient of logsumexp(a) with respect to a into ``out``.

    out[k] = exp(a[k] - logsumexp(a)), i.e. the softmax of ``a``.

    If every entry of ``a`` is -inf the gradient is undefined (the formula
    gives NaN); ``out`` is filled with zeros instead, so that a caller that
    multiplies this row by a zero weight gets 0 rather than NaN.

    ``out`` must have the same length as ``a``; this is asserted because the
    module is compiled with boundscheck=False.
    """
    cdef int K, k
    cdef double m, lse
    K = a.shape[0]
    assert out.shape[0] == K

    # Detect the all -inf (no probability mass) case directly.
    m = -INFINITY
    for k in range(K):
        m = fmax(m, a[k])
    if m == -INFINITY:
        for k in range(K):
            out[k] = 0.0
        return

    lse = logsumexp(a)
    for k in range(K):
        out[k] = exp(a[k] - lse)
40
41
cpdef forward_pass(double[::1] log_pi0,
                   double[:,:,::1] log_Ps,
                   double[:,::1] log_likes,
                   double[:,::1] alphas):
    """Run the forward recursion in log space, filling ``alphas`` in place.

    alphas[0, k]   = log_pi0[k] + log_likes[0, k]
    alphas[t+1, k] = logsumexp_j(alphas[t, j] + log_Ps[s, j, k]) + log_likes[t+1, k]

    where s = t when ``log_Ps`` has T-1 slices (time-varying transitions)
    and s = 0 when it has a single slice.

    Parameters
    ----------
    log_pi0 : (K,) initial log weights.
    log_Ps : (T-1, K, K) or (1, K, K) log transition weights.
    log_likes : (T, K) per-step log likelihoods.
    alphas : (T, K) output buffer, overwritten.

    Returns
    -------
    logsumexp(alphas[T-1]), the log normalizer of the sequence.

    Shapes are checked with asserts because the module disables bounds and
    wraparound checking; T must be at least 1.
    """
    cdef int T, K, t, k, j
    T = log_likes.shape[0]
    K = log_likes.shape[1]
    # T == 0 would make alphas[T-1] an out-of-bounds access (wraparound=False)
    assert T > 0
    assert log_pi0.shape[0] == K
    assert log_Ps.shape[0] == T-1 or log_Ps.shape[0] == 1
    assert log_Ps.shape[1] == K
    assert log_Ps.shape[2] == K
    assert alphas.shape[0] == T
    assert alphas.shape[1] == K

    cdef double[::1] tmp = np.zeros(K)

    # Trick for handling time-varying transition matrices:
    # index t * hetero selects slice t when hetero == 1, else always slice 0.
    cdef int hetero = (log_Ps.shape[0] == T-1)

    for k in range(K):
        alphas[0, k] = log_pi0[k] + log_likes[0, k]

    for t in range(T - 1):
        for k in range(K):
            for j in range(K):
                tmp[j] = alphas[t, j] + log_Ps[t * hetero, j, k]
            alphas[t+1, k] = logsumexp(tmp) + log_likes[t+1, k]

    return logsumexp(alphas[T-1])
71
cpdef backward_pass(double[:,:,::1] log_Ps,
                    double[:,::1] log_likes,
                    double[:,::1] betas):
    """Run the backward recursion in log space, filling ``betas`` in place.

    betas[T-1, k] = 0
    betas[t, k]   = logsumexp_j(log_Ps[s, k, j] + betas[t+1, j] + log_likes[t+1, j])

    where s = t when ``log_Ps`` has T-1 slices (time-varying transitions)
    and s = 0 when it has a single slice.

    Parameters
    ----------
    log_Ps : (T-1, K, K) or (1, K, K) log transition weights.
    log_likes : (T, K) per-step log likelihoods.
    betas : (T, K) output buffer, overwritten.

    An empty sequence (T == 0) is a no-op.
    """
    cdef int T, K, t, k, j
    T = log_likes.shape[0]
    K = log_likes.shape[1]
    assert log_Ps.shape[0] == T-1 or log_Ps.shape[0] == 1
    assert log_Ps.shape[1] == K
    assert log_Ps.shape[2] == K
    assert betas.shape[0] == T
    assert betas.shape[1] == K

    # Nothing to fill. Without this guard the initialization below would write
    # betas[-1, k], which is out of bounds with wraparound/boundscheck disabled.
    if T == 0:
        return

    cdef double[::1] tmp = np.zeros(K)

    # Trick for handling time-varying transition matrices:
    # index t * hetero selects slice t when hetero == 1, else always slice 0.
    cdef int hetero = (log_Ps.shape[0] == T-1)

    # Initialize the last output
    for k in range(K):
        betas[T-1, k] = 0

    for t in range(T-2, -1, -1):
        # betas[t, k] = logsumexp_j(log_Ps[t, k, j] + betas[t+1, j] + log_likes[t+1, j])
        for k in range(K):
            for j in range(K):
                tmp[j] = log_Ps[t * hetero, k, j] + betas[t+1, j] + log_likes[t+1, j]
            betas[t, k] = logsumexp(tmp)
100
101
cpdef backward_sample(double[:,:,::1] log_Ps,
                      double[:,::1] log_likes,
                      double[:,::1] alphas,
                      double[::1] us,
                      long[::1] zs):
    """Sample a state sequence backward in time from forward messages.

    z[T-1] is drawn with log weights alphas[T-1]; each earlier z[t] is drawn
    with log weights alphas[t, k] + log_Ps[s, k, z[t+1]], where s = t when
    ``log_Ps`` has T-1 slices and s = 0 when it has one.

    Each draw is made by inverse-CDF lookup: z[t] is the first k whose
    cumulative normalized weight exceeds us[t], falling back to K-1.
    ``us`` presumably holds uniform [0, 1) variates supplied by the caller.

    Parameters
    ----------
    log_Ps : (T-1, K, K) or (1, K, K) log transition weights.
    log_likes : (T, K); only its shape is used here.
    alphas : (T, K) forward messages.
    us : (T,) random variates, one per time step.
    zs : (T,) output buffer for the sampled states, overwritten.
    """
    cdef int T, K, t, k
    cdef double Z, acc

    T = log_likes.shape[0]
    K = log_likes.shape[1]
    assert log_Ps.shape[0] == T-1 or log_Ps.shape[0] == 1
    assert log_Ps.shape[1] == K
    assert log_Ps.shape[2] == K
    assert alphas.shape[0] == T
    assert alphas.shape[1] == K
    assert us.shape[0] == T
    assert zs.shape[0] == T

    # lpzp1 holds the transition log weights into the already-sampled z[t+1];
    # it starts at zero because z[T-1] has no successor.
    cdef double[::1] lpzp1 = np.zeros(K)
    cdef double[::1] lpz = np.zeros(K)

    # Trick for handling time-varying transition matrices
    cdef int hetero = (log_Ps.shape[0] == T-1)

    for t in range(T-1, -1, -1):
        # compute normalized log p(z[t] = k | z[t+1])
        for k in range(K):
            lpz[k] = lpzp1[k] + alphas[t, k]
        Z = logsumexp(lpz)

        # sample by walking the CDF; C exp avoids a Python call per element
        acc = 0
        zs[t] = K-1
        for k in range(K):
            acc += exp(lpz[k] - Z)
            if us[t] < acc:
                zs[t] = k
                break

        # set the transition potential for sampling z[t-1]
        if t > 0:
            for k in range(K):
                lpzp1[k] = log_Ps[(t-1) * hetero, k, zs[t]]
146
147
cpdef grad_hmm_normalizer(double[:,:,::1] log_Ps,
                          double[:,::1] alphas,
                          double[::1] d_log_pi0,
                          double[:,:,::1] d_log_Ps,
                          double[:,::1] d_log_likes):
    """Backpropagate through the forward recursion's log normalizer.

    Given forward messages ``alphas`` (as produced by the forward recursion
    alphas[t, k] = logsumexp_j(alphas[t-1, j] + log_Ps[t-1, j, k]) + log_likes[t, k]),
    compute gradients of logsumexp(alphas[T-1]) with respect to the initial
    log weights, the log transition weights and the per-step log likelihoods.

    Parameters
    ----------
    log_Ps : (T-1, K, K) or (1, K, K) log transition weights.
    alphas : (T, K) forward messages; T must be at least 1.
    d_log_pi0 : (K,) output, overwritten.
    d_log_Ps : same shape as log_Ps, output. Gradients are ADDED (+=) to
        the existing contents, so the caller should pass a zeroed buffer
        unless accumulation is intended.
    d_log_likes : (T, K) output, overwritten (every row is written).
    """
    cdef int T, K, t, k, j

    T = alphas.shape[0]
    K = alphas.shape[1]
    # T == 0 would make alphas[T-1] / d_log_likes[T-1] out of bounds (wraparound=False)
    assert T > 0
    assert (log_Ps.shape[0] == T-1) or (log_Ps.shape[0] == 1)
    assert d_log_Ps.shape[0] == log_Ps.shape[0]
    assert log_Ps.shape[1] == d_log_Ps.shape[1] == K
    assert log_Ps.shape[2] == d_log_Ps.shape[2] == K
    assert d_log_pi0.shape[0] == K
    assert d_log_likes.shape[0] == T
    assert d_log_likes.shape[1] == K

    # Initialize temp storage for gradients
    cdef double[::1] tmp1 = np.zeros((K,))
    cdef double[:, ::1] tmp2 = np.zeros((K, K))

    # Trick for handling time-varying transition matrices
    cdef int hetero = (log_Ps.shape[0] == T-1)

    # d logsumexp(alphas[T-1]) / d alphas[T-1] is the softmax of the last row
    dlse(alphas[T-1], d_log_likes[T-1])
    for t in range(T-1, 0, -1):
        # tmp2 = dLSE_da(alphas[t-1], log_Ps[t-1])
        #      = np.exp(alphas[t-1] + log_Ps[t-1].T - logsumexp(alphas[t-1] + log_Ps[t-1].T, axis=1))
        #      = [dlse(alphas[t-1] + log_Ps[t-1, :, k]) for k in range(K)]
        for k in range(K):
            for j in range(K):
                tmp1[j] = alphas[t-1, j] + log_Ps[(t-1) * hetero, j, k]
            dlse(tmp1, tmp2[k])

        # d_log_Ps[t-1] = vjp_LSE_B(alphas[t-1], log_Ps[t-1], d_log_likes[t])
        #               = d_log_likes[t] * dLSE_da(alphas[t-1], log_Ps[t-1]).T
        #               = d_log_likes[t] * tmp2.T
        #
        # d_log_Ps[t-1, j, k] = d_log_likes[t, k] * tmp2.T[j, k]
        #                     = d_log_likes[t, k] * tmp2[k, j]
        # (+= so that a shared, time-invariant slice sums over all t)
        for j in range(K):
            for k in range(K):
                d_log_Ps[(t-1) * hetero, j, k] += d_log_likes[t, k] * tmp2[k, j]

        # d_log_likes[t-1] = d_log_likes[t].dot(dLSE_da(alphas[t-1], log_Ps[t-1]))
        #                  = d_log_likes[t].dot(tmp2)
        for k in range(K):
            d_log_likes[t-1, k] = 0
            for j in range(K):
                d_log_likes[t-1, k] += d_log_likes[t, j] * tmp2[j, k]

    # d_log_pi0 = d_log_likes[0]
    for k in range(K):
        d_log_pi0[k] = d_log_likes[0, k]
204
205
cpdef compute_stationary_expected_joints(
    double[:,::1] alphas,
    double[:,::1] betas,
    double[:,::1] lls,
    double[:,::1] log_P,
    double[:,::1] E_zzp1):
    """Add the time-summed, per-step-normalized pairwise joints into E_zzp1.

    For every s in 0..T-2 the (K, K) table

        w[a, b] = exp(alphas[s, a] + betas[s+1, b] + lls[s+1, b] + log_P[a, b] - peak)

    is formed (peak = its maximum entry, for a stable exponent), normalized
    by its total plus 1e-16, and added to E_zzp1. A single log_P is used for
    every step. E_zzp1 is accumulated into, not overwritten.
    """
    cdef int T = alphas.shape[0]
    cdef int K = alphas.shape[1]
    assert betas.shape[0] == T and betas.shape[1] == K
    assert lls.shape[0] == T and lls.shape[1] == K
    assert log_P.shape[0] == K and log_P.shape[1] == K
    assert E_zzp1.shape[0] == K and E_zzp1.shape[1] == K

    cdef int s, a, b
    cdef double peak, norm, val
    cdef double[:, ::1] joint = np.zeros((K, K))

    for s in range(T - 1):
        # Pass 1: unnormalized log joint of (z[s], z[s+1]) and its maximum.
        peak = -INFINITY
        for a in range(K):
            for b in range(K):
                val = alphas[s, a] + betas[s+1, b] + lls[s+1, b] + log_P[a, b]
                joint[a, b] = val
                if val > peak:
                    peak = val

        # Pass 2: shift by the maximum before exponentiating, summing as we go.
        norm = 0.0
        for a in range(K):
            for b in range(K):
                val = exp(joint[a, b] - peak)
                joint[a, b] = val
                norm += val
        # small constant keeps the denominator away from zero
        norm = norm + 1e-16

        # Pass 3: accumulate the normalized table.
        for a in range(K):
            for b in range(K):
                E_zzp1[a, b] += joint[a, b] / norm
245