# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
#         Fabian Pedregosa <fabian.pedregosa@inria.fr>
#         Olivier Grisel <olivier.grisel@ensta.org>
#         Alexis Mignon <alexis.mignon@gmail.com>
#         Manoj Kumar <manojkumarsivaraj334@gmail.com>
#
# License: BSD 3 clause

from libc.math cimport fabs
cimport numpy as np
import numpy as np
import numpy.linalg as linalg

cimport cython
from cpython cimport bool
from cython cimport floating
import warnings
from ..exceptions import ConvergenceWarning

from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
                                   _copy, _scal)
from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans


from ..utils._random cimport our_rand_r

ctypedef np.float64_t DOUBLE
ctypedef np.uint32_t UINT32_t

np.import_array()
# The following enum and function are shamelessly copied from the tree code.

cdef enum:
    # Max value for our rand_r replacement (our_rand_r, imported from
    # ..utils._random). We don't use RAND_MAX because it's different across
    # platforms and particularly tiny on Windows/MSVC.
    RAND_R_MAX = 0x7FFFFFFF


cdef inline UINT32_t rand_int(UINT32_t end, UINT32_t* random_state) nogil:
    """Generate a random integer in [0; end)."""
    return our_rand_r(random_state) % end


cdef inline floating fmax(floating x, floating y) nogil:
    if x > y:
        return x
    return y


cdef inline floating fsign(floating f) nogil:
    if f == 0:
        return 0
    elif f > 0:
        return 1.0
    else:
        return -1.0


cdef floating abs_max(int n, floating* a) nogil:
    """np.max(np.abs(a))"""
    cdef int i
    cdef floating m = fabs(a[0])
    cdef floating d
    for i in range(1, n):
        d = fabs(a[i])
        if d > m:
            m = d
    return m


cdef floating max(int n, floating* a) nogil:
    """np.max(a)"""
    cdef int i
    cdef floating m = a[0]
    cdef floating d
    for i in range(1, n):
        d = a[i]
        if d > m:
            m = d
    return m


cdef floating diff_abs_max(int n, floating* a, floating* b) nogil:
    """np.max(np.abs(a - b))"""
    cdef int i
    cdef floating m = fabs(a[0] - b[0])
    cdef floating d
    for i in range(1, n):
        d = fabs(a[i] - b[i])
        if d > m:
            m = d
    return m


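# A minimal pure-NumPy sketch of the coordinate update used by the solvers
# below, for reference only (this helper is illustrative and not called
# anywhere): each coordinate is updated in closed form by soft-thresholding
# the correlation of its column with the current residual.

def _enet_cd_sweep_reference(w, X, y, alpha, beta):
    """One cyclic pass of coordinate descent on
    (1/2) * norm(y - X w, 2)^2 + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2.
    """
    R = y - X @ w  # current residual
    for ii in range(X.shape[1]):
        norm_ii = X[:, ii] @ X[:, ii]
        if norm_ii == 0.0:
            continue
        R += X[:, ii] * w[ii]  # remove this coordinate's contribution
        tmp = X[:, ii] @ R
        # soft-thresholding step; the l2 term shrinks via the denominator
        w[ii] = (np.sign(tmp) * np.maximum(np.abs(tmp) - alpha, 0.0)
                 / (norm_ii + beta))
        R -= X[:, ii] * w[ii]  # restore the residual
    return w

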
def enet_coordinate_descent(floating[::1] w,
                            floating alpha, floating beta,
                            floating[::1, :] X,
                            floating[::1] y,
                            int max_iter, floating tol,
                            object rng, bint random=0, bint positive=0):
    """Cython version of the coordinate descent algorithm
    for Elastic-Net regression

    We minimize

    (1/2) * norm(y - X w, 2)^2 + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2

    """

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    # get the data information into easy vars
    cdef unsigned int n_samples = X.shape[0]
    cdef unsigned int n_features = X.shape[1]

    # compute norms of the columns of X
    cdef floating[::1] norm_cols_X = np.square(X).sum(axis=0)

    # initial value of the residuals
    cdef floating[::1] R = np.empty(n_samples, dtype=dtype)
    cdef floating[::1] XtA = np.empty(n_features, dtype=dtype)

    cdef floating tmp
    cdef floating w_ii
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating gap = tol + 1.0
    cdef floating d_w_tol = tol
    cdef floating dual_norm_XtA
    cdef floating R_norm2
    cdef floating w_norm2
    cdef floating l1_norm
    cdef floating const
    cdef floating A_norm2
    cdef unsigned int ii
    cdef unsigned int i
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

    if alpha == 0 and beta == 0:
        warnings.warn("Coordinate descent with no regularization may lead to "
                      "unexpected results and is discouraged.")

    with nogil:
        # R = y - np.dot(X, w)
        _copy(n_samples, &y[0], 1, &R[0], 1)
        _gemv(ColMajor, NoTrans, n_samples, n_features, -1.0, &X[0, 0],
              n_samples, &w[0], 1, 1.0, &R[0], 1)

        # tol *= np.dot(y, y)
        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
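        # Scaling tol by ||y||_2^2 makes the duality-gap check below
        # relative to the scale of the targets rather than an absolute
        # threshold.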

        for n_iter in range(max_iter):
            w_max = 0.0
            d_w_max = 0.0
            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                if norm_cols_X[ii] == 0.0:
                    continue

                w_ii = w[ii]  # Store previous value

                if w_ii != 0.0:
                    # R += w_ii * X[:,ii]
                    _axpy(n_samples, w_ii, &X[0, ii], 1, &R[0], 1)

                # tmp = (X[:,ii] * R).sum()
                tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1)

                if positive and tmp < 0:
                    w[ii] = 0.0
                else:
                    w[ii] = (fsign(tmp) * fmax(fabs(tmp) - alpha, 0)
                             / (norm_cols_X[ii] + beta))

                if w[ii] != 0.0:
                    # R -= w[ii] * X[:,ii]  # Update residual
                    _axpy(n_samples, -w[ii], &X[0, ii], 1, &R[0], 1)

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
                d_w_max = fmax(d_w_max, d_w_ii)

                w_max = fmax(w_max, fabs(w[ii]))

            if (w_max == 0.0 or
                d_w_max / w_max < d_w_tol or
                n_iter == max_iter - 1):
                # the biggest coordinate update of this iteration was smaller
                # than the tolerance: check the duality gap as ultimate
                # stopping criterion

                # XtA = np.dot(X.T, R) - beta * w
                _copy(n_features, &w[0], 1, &XtA[0], 1)
                _gemv(ColMajor, Trans,
                      n_samples, n_features, 1.0, &X[0, 0], n_samples,
                      &R[0], 1,
                      -beta, &XtA[0], 1)

                if positive:
                    dual_norm_XtA = max(n_features, &XtA[0])
                else:
                    dual_norm_XtA = abs_max(n_features, &XtA[0])

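                # If the dual constraint norm(XtA, inf) <= alpha is violated,
                # the residual R is rescaled by const = alpha / dual_norm_XtA
                # below to obtain a feasible dual point; the duality gap
                # computed from it upper-bounds the suboptimality of w.
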
                # R_norm2 = np.dot(R, R)
                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)

                # w_norm2 = np.dot(w, w)
                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)

                if dual_norm_XtA > alpha:
                    const = alpha / dual_norm_XtA
                    A_norm2 = R_norm2 * (const ** 2)
                    gap = 0.5 * (R_norm2 + A_norm2)
                else:
                    const = 1.0
                    gap = R_norm2

                l1_norm = _asum(n_features, &w[0], 1)

                # np.dot(R.T, y)
                gap += (alpha * l1_norm
                        - const * _dot(n_samples, &R[0], 1, &y[0], 1)
                        + 0.5 * beta * (1 + const ** 2) * (w_norm2))

                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                message = (
                    "Objective did not converge. You might want to increase "
                    "the number of iterations, check the scale of the "
                    "features or consider increasing regularisation. "
                    f"Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
                )
                if alpha < np.finfo(np.float64).eps:
                    message += (
                        " Linear regression models with null weight for the "
                        "l1 regularization term are more efficiently fitted "
                        "using one of the solvers implemented in "
                        "sklearn.linear_model.Ridge/RidgeCV instead."
                    )
                warnings.warn(message, ConvergenceWarning)

    return w, gap, tol, n_iter + 1


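# The sparse solver below receives X in CSC layout as three flat arrays:
# the nonzero values of column ii are X_data[X_indptr[ii]:X_indptr[ii + 1]],
# and their row indices are the matching slice of X_indices. A pure-Python
# sketch of the column/residual dot product used throughout (illustrative
# only; this helper is not called anywhere):

def _csc_column_dot_reference(X_data, X_indices, X_indptr, R, ii):
    """np.dot(X[:, ii], R) for a CSC matrix stored as three flat arrays."""
    tmp = 0.0
    for jj in range(X_indptr[ii], X_indptr[ii + 1]):
        tmp += X_data[jj] * R[X_indices[jj]]
    return tmp

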
def sparse_enet_coordinate_descent(floating[::1] w,
                                   floating alpha, floating beta,
                                   np.ndarray[floating, ndim=1, mode='c'] X_data,
                                   np.ndarray[int, ndim=1, mode='c'] X_indices,
                                   np.ndarray[int, ndim=1, mode='c'] X_indptr,
                                   np.ndarray[floating, ndim=1] y,
                                   floating[:] X_mean, int max_iter,
                                   floating tol, object rng, bint random=0,
                                   bint positive=0):
    """Cython version of the coordinate descent algorithm for Elastic-Net

    We minimize:

        (1/2) * norm(y - X w, 2)^2 + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2

    """

    # get the data information into easy vars
    cdef unsigned int n_samples = y.shape[0]
    cdef unsigned int n_features = w.shape[0]

    # compute norms of the columns of X
    cdef unsigned int ii
    cdef floating[:] norm_cols_X

    cdef unsigned int startptr = X_indptr[0]
    cdef unsigned int endptr

    # initial value of the residuals
    cdef floating[:] R = y.copy()

    cdef floating[:] X_T_R
    cdef floating[:] XtA

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    norm_cols_X = np.zeros(n_features, dtype=dtype)
    X_T_R = np.zeros(n_features, dtype=dtype)
    XtA = np.zeros(n_features, dtype=dtype)

    cdef floating tmp
    cdef floating w_ii
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating X_mean_ii
    cdef floating R_sum = 0.0
    cdef floating R_norm2
    cdef floating w_norm2
    cdef floating A_norm2
    cdef floating l1_norm
    cdef floating normalize_sum
    cdef floating gap = tol + 1.0
    cdef floating d_w_tol = tol
    cdef floating dual_norm_XtA
    cdef floating const
    cdef unsigned int jj
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed
    cdef bint center = False

    with nogil:
        # center = (X_mean != 0).any()
        for ii in range(n_features):
            if X_mean[ii]:
                center = True
                break

        for ii in range(n_features):
            X_mean_ii = X_mean[ii]
            endptr = X_indptr[ii + 1]
            normalize_sum = 0.0
            w_ii = w[ii]

            for jj in range(startptr, endptr):
                normalize_sum += (X_data[jj] - X_mean_ii) ** 2
                R[X_indices[jj]] -= X_data[jj] * w_ii
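            # The implicit zeros of the centered column each contribute
            # (0 - X_mean_ii) ** 2 to its squared norm; there are
            # n_samples - (endptr - startptr) of them, hence the correction
            # term below.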
            norm_cols_X[ii] = normalize_sum + \
                (n_samples - endptr + startptr) * X_mean_ii ** 2

            if center:
                for jj in range(n_samples):
                    R[jj] += X_mean_ii * w_ii
            startptr = endptr

        # tol *= np.dot(y, y)
        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)

        for n_iter in range(max_iter):

            w_max = 0.0
            d_w_max = 0.0

            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                if norm_cols_X[ii] == 0.0:
                    continue

                startptr = X_indptr[ii]
                endptr = X_indptr[ii + 1]
                w_ii = w[ii]  # Store previous value
                X_mean_ii = X_mean[ii]

                if w_ii != 0.0:
                    # R += w_ii * X[:,ii]
                    for jj in range(startptr, endptr):
                        R[X_indices[jj]] += X_data[jj] * w_ii
                    if center:
                        for jj in range(n_samples):
                            R[jj] -= X_mean_ii * w_ii

                # tmp = (X[:,ii] * R).sum()
                tmp = 0.0
                for jj in range(startptr, endptr):
                    tmp += R[X_indices[jj]] * X_data[jj]

                if center:
                    R_sum = 0.0
                    for jj in range(n_samples):
                        R_sum += R[jj]
                    tmp -= R_sum * X_mean_ii

                if positive and tmp < 0.0:
                    w[ii] = 0.0
                else:
                    w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
                            / (norm_cols_X[ii] + beta)

                if w[ii] != 0.0:
                    # R -= w[ii] * X[:,ii]  # Update residual
                    for jj in range(startptr, endptr):
                        R[X_indices[jj]] -= X_data[jj] * w[ii]

                    if center:
                        for jj in range(n_samples):
                            R[jj] += X_mean_ii * w[ii]

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
                if d_w_ii > d_w_max:
                    d_w_max = d_w_ii

                if fabs(w[ii]) > w_max:
                    w_max = fabs(w[ii])

            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                # the biggest coordinate update of this iteration was smaller than
                # the tolerance: check the duality gap as ultimate stopping
                # criterion

                # sparse X.T / dense R dot product
                if center:
                    R_sum = 0.0
                    for jj in range(n_samples):
                        R_sum += R[jj]

                for ii in range(n_features):
                    X_T_R[ii] = 0.0
                    for jj in range(X_indptr[ii], X_indptr[ii + 1]):
                        X_T_R[ii] += X_data[jj] * R[X_indices[jj]]

                    if center:
                        X_T_R[ii] -= X_mean[ii] * R_sum
                    XtA[ii] = X_T_R[ii] - beta * w[ii]

                if positive:
                    dual_norm_XtA = max(n_features, &XtA[0])
                else:
                    dual_norm_XtA = abs_max(n_features, &XtA[0])

                # R_norm2 = np.dot(R, R)
                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)

                # w_norm2 = np.dot(w, w)
                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
                if dual_norm_XtA > alpha:
                    const = alpha / dual_norm_XtA
                    A_norm2 = R_norm2 * const ** 2
                    gap = 0.5 * (R_norm2 + A_norm2)
                else:
                    const = 1.0
                    gap = R_norm2

                l1_norm = _asum(n_features, &w[0], 1)

                gap += (alpha * l1_norm
                        - const * _dot(n_samples, &R[0], 1, &y[0], 1)
                        + 0.5 * beta * (1 + const ** 2) * w_norm2)

                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                warnings.warn("Objective did not converge. You might want to "
                              "increase the number of iterations. Duality "
                              "gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return w, gap, tol, n_iter + 1


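# Running coordinate descent on the Gram matrix avoids touching X inside the
# main loop, which pays off when n_samples >> n_features. The reformulation
# below is an algebraic identity: up to the constant (1/2) * ||y||^2,
# (1/2) * ||y - X w||^2 equals (1/2) * w^T (X^T X) w - (X^T y)^T w. A sketch
# of how the inputs relate to X and y (illustrative only; the scikit-learn
# entry points prepare these arrays themselves):

def _gram_inputs_reference(X, y):
    """Return (Q, q) with Q = X^T X and q = X^T y, as expected below."""
    Q = np.ascontiguousarray(X.T @ X)  # Gram matrix, C-ordered
    q = np.ascontiguousarray(X.T @ y)
    return Q, q

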
def enet_coordinate_descent_gram(floating[::1] w,
                                 floating alpha, floating beta,
                                 np.ndarray[floating, ndim=2, mode='c'] Q,
                                 np.ndarray[floating, ndim=1, mode='c'] q,
                                 np.ndarray[floating, ndim=1] y,
                                 int max_iter, floating tol, object rng,
                                 bint random=0, bint positive=0):
    """Cython version of the coordinate descent algorithm
    for Elastic-Net regression

    We minimize

    (1/2) * w^T Q w - q^T w + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2

    which amounts to the Elastic-Net problem when:
    Q = X^T X (Gram matrix)
    q = X^T y
    """

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    # get the data information into easy vars
    cdef unsigned int n_samples = y.shape[0]
    cdef unsigned int n_features = Q.shape[0]

    # initial value of "Q w", which will be kept up to date during the
    # iterations
    cdef floating[:] H = np.dot(Q, w)

    cdef floating[:] XtA = np.zeros(n_features, dtype=dtype)
    cdef floating tmp
    cdef floating w_ii
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating q_dot_w
    cdef floating R_norm2
    cdef floating w_norm2
    cdef floating A_norm2
    cdef floating const
    cdef floating gap = tol + 1.0
    cdef floating d_w_tol = tol
    cdef floating dual_norm_XtA
    cdef unsigned int ii
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

    cdef floating y_norm2 = np.dot(y, y)
    cdef floating* w_ptr = <floating*>&w[0]
    cdef floating* Q_ptr = &Q[0, 0]
    cdef floating* q_ptr = <floating*>q.data
    cdef floating* H_ptr = &H[0]
    cdef floating* XtA_ptr = &XtA[0]
    tol = tol * y_norm2

    if alpha == 0:
        warnings.warn("Coordinate descent with alpha=0 may lead to unexpected"
                      " results and is discouraged.")

    with nogil:
        for n_iter in range(max_iter):
            w_max = 0.0
            d_w_max = 0.0
            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                if Q[ii, ii] == 0.0:
                    continue

                w_ii = w[ii]  # Store previous value

                if w_ii != 0.0:
                    # H -= w_ii * Q[ii]
                    _axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
                          H_ptr, 1)

                tmp = q[ii] - H[ii]

                if positive and tmp < 0:
                    w[ii] = 0.0
                else:
                    w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
                        / (Q[ii, ii] + beta)

                if w[ii] != 0.0:
                    # H += w[ii] * Q[ii]  # Update H = X.T X w
                    _axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
                          H_ptr, 1)

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
                if d_w_ii > d_w_max:
                    d_w_max = d_w_ii

                if fabs(w[ii]) > w_max:
                    w_max = fabs(w[ii])

            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                # the biggest coordinate update of this iteration was smaller than
                # the tolerance: check the duality gap as ultimate stopping
                # criterion

                # q_dot_w = np.dot(w, q)
                q_dot_w = _dot(n_features, w_ptr, 1, q_ptr, 1)

                for ii in range(n_features):
                    XtA[ii] = q[ii] - H[ii] - beta * w[ii]
                if positive:
                    dual_norm_XtA = max(n_features, XtA_ptr)
                else:
                    dual_norm_XtA = abs_max(n_features, XtA_ptr)

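                # With H = Q w = X^T X w and q = X^T y, the squared residual
                # norm is recovered without X:
                # ||y - X w||^2 = ||y||^2 + w^T H - 2 * q^T w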
                # tmp = np.sum(w * H)
                tmp = 0.0
                for ii in range(n_features):
                    tmp += w[ii] * H[ii]
                R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w

                # w_norm2 = np.dot(w, w)
                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)

                if dual_norm_XtA > alpha:
                    const = alpha / dual_norm_XtA
                    A_norm2 = R_norm2 * (const ** 2)
                    gap = 0.5 * (R_norm2 + A_norm2)
                else:
                    const = 1.0
                    gap = R_norm2

                # The call to asum is equivalent to the L1 norm of w
                gap += (alpha * _asum(n_features, &w[0], 1) -
                        const * y_norm2 + const * q_dot_w +
                        0.5 * beta * (1 + const ** 2) * w_norm2)

                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                warnings.warn("Objective did not converge. You might want to "
                              "increase the number of iterations. Duality "
                              "gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return np.asarray(w), gap, tol, n_iter + 1


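# The multi-task penalty ties the tasks together through the l21 norm of W:
# an l2 norm over tasks for each feature, summed over features. Each
# coordinate update below is therefore a group soft-thresholding of the
# whole task vector W[:, ii]. A pure-NumPy sketch of that update
# (illustrative only; this helper is not called anywhere):

def _group_soft_threshold_reference(tmp, l1_reg, l2_reg, norm_col):
    """New value of W[:, ii] given tmp = X[:, ii] @ R (one entry per task)."""
    nn = np.sqrt(np.sum(tmp ** 2))  # l2 norm of the task vector
    if nn == 0.0:
        # the Cython loop below reaches the same result via IEEE arithmetic
        return np.zeros_like(tmp)
    shrink = 1.0 - l1_reg / nn
    if shrink < 0.0:
        shrink = 0.0  # the whole group is zeroed out
    return tmp * shrink / (norm_col + l2_reg)

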
def enet_coordinate_descent_multi_task(
        floating[::1, :] W, floating l1_reg, floating l2_reg,
        np.ndarray[floating, ndim=2, mode='fortran'] X,  # TODO: use views with Cython 3.0
        np.ndarray[floating, ndim=2, mode='fortran'] Y,  # hopefully with skl 1.0
        int max_iter, floating tol, object rng, bint random=0):
    """Cython version of the coordinate descent algorithm
    for Elastic-Net multi-task regression

    We minimize

    0.5 * norm(Y - X W.T, 2)^2 + l1_reg * ||W.T||_21 + 0.5 * l2_reg * norm(W.T, 2)^2

    """

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    # get the data information into easy vars
    cdef unsigned int n_samples = X.shape[0]
    cdef unsigned int n_features = X.shape[1]
    cdef unsigned int n_tasks = Y.shape[1]

    # to store XtA
    cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype)
    cdef floating XtA_axis1norm
    cdef floating dual_norm_XtA

    # initial value of the residuals
    cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F')

    cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
    cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype)
    cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype)
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating nn
    cdef floating W_ii_abs_max
    cdef floating gap = tol + 1.0
    cdef floating d_w_tol = tol
    cdef floating R_norm
    cdef floating w_norm
    cdef floating A_norm
    cdef floating const
    cdef floating ry_sum
    cdef floating l21_norm
    cdef unsigned int ii
    cdef unsigned int jj
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

    cdef floating* X_ptr = &X[0, 0]
    cdef floating* Y_ptr = &Y[0, 0]

    if l1_reg == 0:
        warnings.warn("Coordinate descent with l1_reg=0 may lead to unexpected"
                      " results and is discouraged.")

    with nogil:
        # norm_cols_X = (np.asarray(X) ** 2).sum(axis=0)
        for ii in range(n_features):
            norm_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2

        # R = Y - np.dot(X, W.T)
        _copy(n_samples * n_tasks, Y_ptr, 1, &R[0, 0], 1)
        for ii in range(n_features):
            for jj in range(n_tasks):
                if W[jj, ii] != 0:
                    _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
                          &R[0, jj], 1)

        # tol = tol * linalg.norm(Y, ord='fro') ** 2
        tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2

        for n_iter in range(max_iter):
            w_max = 0.0
            d_w_max = 0.0
            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                if norm_cols_X[ii] == 0.0:
                    continue

                # w_ii = W[:, ii]  # Store previous value
                _copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1)

                # Using numpy:
                # R += np.dot(X[:, ii][:, None], w_ii[None, :])  # rank 1 update
                # Using BLAS Level 2:
                # _ger(RowMajor, n_samples, n_tasks, 1.0,
                #      &X[0, ii], 1,
                #      &w_ii[0], 1, &R[0, 0], n_tasks)
                # Using BLAS Level 1 and a for loop to avoid slower threads
                # for such small vectors
                for jj in range(n_tasks):
                    if w_ii[jj] != 0:
                        _axpy(n_samples, w_ii[jj], X_ptr + ii * n_samples, 1,
                              &R[0, jj], 1)

                # Using numpy:
                # tmp = np.dot(X[:, ii][None, :], R).ravel()
                # Using BLAS Level 2:
                # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
                #       n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
                # Using BLAS Level 1 (faster for small vectors like here):
                for jj in range(n_tasks):
                    tmp[jj] = _dot(n_samples, X_ptr + ii * n_samples, 1,
                                   &R[0, jj], 1)

                # nn = sqrt(np.sum(tmp ** 2))
                nn = _nrm2(n_tasks, &tmp[0], 1)

                # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
                _copy(n_tasks, &tmp[0], 1, &W[0, ii], 1)
                _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
                      &W[0, ii], 1)

                # Using numpy:
                # R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
                # Using BLAS Level 2:
                # Update residual : rank 1 update
                # _ger(RowMajor, n_samples, n_tasks, -1.0,
                #      &X[0, ii], 1, &W[0, ii], 1,
                #      &R[0, 0], n_tasks)
                # Using BLAS Level 1 (faster for small vectors like here):
                for jj in range(n_tasks):
                    if W[jj, ii] != 0:
                        _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
                              &R[0, jj], 1)

                # update the maximum absolute coefficient update
                d_w_ii = diff_abs_max(n_tasks, &W[0, ii], &w_ii[0])

                if d_w_ii > d_w_max:
                    d_w_max = d_w_ii

                W_ii_abs_max = abs_max(n_tasks, &W[0, ii])
                if W_ii_abs_max > w_max:
                    w_max = W_ii_abs_max

            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                # the biggest coordinate update of this iteration was smaller than
                # the tolerance: check the duality gap as ultimate stopping
                # criterion

                # XtA = np.dot(X.T, R) - l2_reg * W.T
                for ii in range(n_features):
                    for jj in range(n_tasks):
                        XtA[ii, jj] = _dot(
                            n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
                            ) - l2_reg * W[jj, ii]

                # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
                dual_norm_XtA = 0.0
                for ii in range(n_features):
                    # np.sqrt(np.sum(XtA ** 2, axis=1))
                    XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1)
                    if XtA_axis1norm > dual_norm_XtA:
                        dual_norm_XtA = XtA_axis1norm

                # TODO: use squared L2 norm directly
                # R_norm = linalg.norm(R, ord='fro')
                # w_norm = linalg.norm(W, ord='fro')
                R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
                w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
                if dual_norm_XtA > l1_reg:
                    const = l1_reg / dual_norm_XtA
                    A_norm = R_norm * const
                    gap = 0.5 * (R_norm ** 2 + A_norm ** 2)
                else:
                    const = 1.0
                    gap = R_norm ** 2

                # ry_sum = np.sum(R * Y)
                ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)

                # l21_norm = np.sqrt(np.sum(W ** 2, axis=0)).sum()
                l21_norm = 0.0
                for ii in range(n_features):
                    l21_norm += _nrm2(n_tasks, &W[0, ii], 1)

                gap += l1_reg * l21_norm - const * ry_sum + \
                    0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)

                if gap < tol:
                    # return if we reached desired tolerance
                    break
        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                warnings.warn("Objective did not converge. You might want to "
                              "increase the number of iterations. Duality "
                              "gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return np.asarray(W), gap, tol, n_iter + 1