1# Author: Mathieu Blondel, Tom Dupre la Tour
2# License: BSD 3 clause
3
4from cython cimport floating
5from libc.math cimport fabs
6
7
def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
                       floating[:, :] XHt, Py_ssize_t[::1] permutation):
    """Run one cyclic coordinate-descent sweep over the columns of W, in place.

    Each entry W[i, t] receives a single step ``W[i, t] - grad / hess``
    clipped at zero, so W stays non-negative. Updates are written to W
    immediately, so gradients computed later in the sweep see the values
    already updated earlier in the same sweep.

    Parameters
    ----------
    W : floating memoryview, shape (n_samples, n_components), C-contiguous
        Matrix updated in place.
    HHt : floating memoryview, indexed as (n_components, n_components)
        Quadratic-term matrix; row HHt[t, :] and the diagonal entry
        HHt[t, t] are read for column t. Never written.
    XHt : floating memoryview, indexed as (n_samples, n_components)
        Linear-term matrix; XHt[i, t] is read for each coordinate.
    permutation : Py_ssize_t memoryview, C-contiguous
        Its first n_components entries give the order in which columns of
        W are visited (presumably a permutation of range(n_components) --
        not checked here).

    Returns
    -------
    violation : float
        Sum over all visited coordinates of the absolute projected gradient,
        each evaluated just before that coordinate is updated.
    """
    cdef:
        floating violation = 0
        Py_ssize_t n_components = W.shape[1]
        Py_ssize_t n_samples = W.shape[0]  # n_features for H update
        floating grad, pg, hess
        Py_ssize_t i, r, s, t

    with nogil:
        for s in range(n_components):
            t = permutation[s]

            # Hessian for coordinate W[i, t]: it depends only on t, and HHt
            # is never written in this function, so read it once per column.
            hess = HHt[t, t]

            for i in range(n_samples):
                # gradient = GW[i, t] where GW = np.dot(W, HHt) - XHt.
                # HHt is read row-wise (HHt[t, r]), so strictly this is
                # np.dot(W, HHt.T); the two agree when HHt is symmetric.
                grad = -XHt[i, t]

                for r in range(n_components):
                    grad += HHt[t, r] * W[i, r]

                # projected gradient: at the bound W[i, t] == 0 only a
                # negative gradient (which would increase W[i, t]) counts
                pg = min(0., grad) if W[i, t] == 0 else grad
                violation += fabs(pg)

                # step clipped at 0; skipped when hess == 0 to avoid a
                # division by zero, leaving W[i, t] unchanged
                if hess != 0:
                    W[i, t] = max(W[i, t] - grad / hess, 0.)

    return violation
39