# Author: Mathieu Blondel, Tom Dupre la Tour
# License: BSD 3 clause

from cython cimport floating
from libc.math cimport fabs


def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
                       floating[:, :] XHt, Py_ssize_t[::1] permutation):
    """Run one in-place coordinate-descent sweep over the columns of W.

    Columns are visited in the order given by ``permutation``. For each
    visited column ``t`` and each row ``i``:

    - the gradient ``grad = (W @ HHt.T)[i, t] - XHt[i, t]`` is computed from
      the *current* contents of ``W``, so columns updated earlier in the
      same sweep are already taken into account;
    - ``W[i, t]`` is replaced by ``max(W[i, t] - grad / HHt[t, t], 0)``.

    NOTE(review): the names suggest ``HHt = H @ H.T`` (symmetric) and
    ``XHt = X @ H.T``; the callers are not visible here. The original comment
    in the loop assumes this symmetry.

    Parameters
    ----------
    W : 2-D C-contiguous memoryview, shape (n_samples, n_components)
        Modified in place. The per-sample comment in the body notes that
        the same routine serves the H update, with n_features rows.
    HHt : 2-D memoryview, indexed as HHt[t, r] for t, r < n_components
        Only read. Must not share memory with ``W``, because its diagonal
        entry is read once per column.
    XHt : 2-D memoryview, indexed as XHt[i, t]
        Only read.
    permutation : 1-D memoryview of Py_ssize_t
        Its first ``n_components`` entries are the column indices to visit,
        in order.

    Returns
    -------
    violation : float
        Sum of the absolute projected gradients, each evaluated just before
        its coordinate was updated. A column with ``HHt[t, t] == 0`` is not
        updated but still contributes to this sum.
    """
    cdef:
        floating violation = 0
        Py_ssize_t n_components = W.shape[1]
        Py_ssize_t n_samples = W.shape[0]  # n_features for H update
        floating grad, pg, hess
        Py_ssize_t i, r, s, t

    with nogil:
        for s in range(n_components):
            t = permutation[s]

            # Curvature along coordinate t. HHt is never written in this
            # function, so the value is constant over the row loop below;
            # load it once per column instead of once per row.
            hess = HHt[t, t]

            for i in range(n_samples):
                # grad = (W @ HHt.T)[i, t] - XHt[i, t]; this equals
                # (np.dot(W, HHt) - XHt)[i, t] when HHt is symmetric.
                grad = -XHt[i, t]

                for r in range(n_components):
                    grad += HHt[t, r] * W[i, r]

                # Projected gradient for the constraint W >= 0: at the
                # bound (W[i, t] == 0) a positive gradient cannot be
                # followed, so only its negative part counts.
                pg = min(0., grad) if W[i, t] == 0 else grad
                violation += fabs(pg)

                # Step of size 1 / hess against the gradient, clipped at
                # zero to keep W nonnegative. A zero curvature leaves the
                # coordinate untouched.
                if hess != 0:
                    W[i, t] = max(W[i, t] - grad / hess, 0.)

    return violation