# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
#         Fabian Pedregosa <fabian.pedregosa@inria.fr>
#         Olivier Grisel <olivier.grisel@ensta.org>
#         Alexis Mignon <alexis.mignon@gmail.com>
#         Manoj Kumar <manojkumarsivaraj334@gmail.com>
#
# License: BSD 3 clause
#
# Cython coordinate descent solvers for Elastic-Net problems, in four
# flavours: dense design matrix, sparse design matrix given by its
# compressed-column arrays, precomputed Gram matrix, and the multi-task
# (l2,1-penalized) variant.

from libc.math cimport fabs
cimport numpy as np
import numpy as np
import numpy.linalg as linalg

cimport cython
from cpython cimport bool
from cython cimport floating
import warnings
from ..exceptions import ConvergenceWarning

from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
                                   _copy, _scal)
from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans


from ..utils._random cimport our_rand_r

ctypedef np.float64_t DOUBLE
ctypedef np.uint32_t UINT32_t

np.import_array()

# The RAND_R_MAX constant and the rand_int helper below were copied from the
# tree code.

cdef enum:
    # Max value for our rand_r replacement (``our_rand_r``, cimported above).
    # We don't use RAND_MAX because it's different across platforms and
    # particularly tiny on Windows/MSVC.
    RAND_R_MAX = 0x7FFFFFFF


cdef inline UINT32_t rand_int(UINT32_t end, UINT32_t* random_state) nogil:
    """Generate a random integer in [0; end).

    The value comes from ``our_rand_r(random_state)`` reduced modulo ``end``,
    so ``end`` must be non-zero.
    """
    return our_rand_r(random_state) % end


cdef inline floating fmax(floating x, floating y) nogil:
    """Return ``x`` if ``x > y``, otherwise ``y``."""
    if x > y:
        return x
    return y


cdef inline floating fsign(floating f) nogil:
    """Return 0 for ``f == 0``, 1.0 for ``f > 0`` and -1.0 otherwise."""
    if f == 0:
        return 0
    elif f > 0:
        return 1.0
    else:
        return -1.0


cdef floating abs_max(int n, floating* a) nogil:
    """np.max(np.abs(a)) over the first ``n`` contiguous entries of ``a``.

    ``a[0]`` is read unconditionally, so ``n`` must be at least 1.
    """
    cdef int i
    cdef floating m = fabs(a[0])
    cdef floating d
    for i in range(1, n):
        d = fabs(a[i])
        if d > m:
            m = d
    return m


cdef floating max(int n, floating* a) nogil:
    """np.max(a) over the first ``n`` contiguous entries of ``a``.

    ``a[0]`` is read unconditionally, so ``n`` must be at least 1.
    """
    cdef int i
    cdef floating m = a[0]
    cdef floating d
    for i in range(1, n):
        d = a[i]
        if d > m:
            m = d
    return m


cdef floating diff_abs_max(int n, floating* a, floating* b) nogil:
    """np.max(np.abs(a - b)) over the first ``n`` contiguous entries.

    ``a[0]`` and ``b[0]`` are read unconditionally, so ``n`` must be >= 1.
    """
    cdef int i
    cdef floating m = fabs(a[0] - b[0])
    cdef floating d
    for i in range(1, n):
        d = fabs(a[i] - b[i])
        if d > m:
            m = d
    return m


def enet_coordinate_descent(floating[::1] w,
                            floating alpha, floating beta,
                            floating[::1, :] X,
                            floating[::1] y,
                            int max_iter, floating tol,
                            object rng, bint random=0, bint positive=0):
    """Cython version of the coordinate descent algorithm
    for Elastic-Net regression.

    We minimize

        (1/2) * norm(y - X w, 2)^2 + alpha norm(w, 1) + (beta/2) norm(w, 2)^2

    Parameters
    ----------
    w : contiguous array of shape (n_features,)
        Initial coefficients; updated in place.
    alpha : floating
        L1 penalty strength.
    beta : floating
        L2 penalty strength.
    X : Fortran-ordered array of shape (n_samples, n_features)
        Design matrix.
    y : contiguous array of shape (n_samples,)
        Target values.
    max_iter : int
        Maximum number of passes over the coordinates.
    tol : floating
        Stopping tolerance. After each pass, the largest coordinate change
        relative to the largest coefficient is compared with ``tol``; when it
        is smaller (or on the last pass) the duality gap is computed and
        compared with ``tol * np.dot(y, y)``.
    rng : object
        Random source exposing ``randint(low, high)``; it is called once to
        seed the coordinate sampler.
    random : bint, default=0
        If true, each of the ``n_features`` updates of a pass picks a
        coordinate uniformly at random (with replacement) instead of cycling
        through the coordinates in order.
    positive : bint, default=0
        If true, coefficients are clipped to be non-negative.

    Returns
    -------
    w : memoryview of shape (n_features,)
        View on the updated coefficients.
    gap : floating
        Last computed duality gap (initialised to ``tol + 1.0`` and only
        overwritten when the gap is evaluated).
    tol : floating
        The scaled tolerance ``tol * np.dot(y, y)``.
    n_iter : int
        ``n_iter + 1``, where ``n_iter`` is the index of the last pass that
        ran (1 is also returned when ``max_iter <= 0``).
    """

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    # get the data information into easy vars
    cdef unsigned int n_samples = X.shape[0]
    cdef unsigned int n_features = X.shape[1]

    # compute norms of the columns of X
    cdef floating[::1] norm_cols_X = np.square(X).sum(axis=0)

    # initial value of the residuals
    cdef floating[::1] R = np.empty(n_samples, dtype=dtype)
    cdef floating[::1] XtA = np.empty(n_features, dtype=dtype)

    cdef floating tmp
    cdef floating w_ii
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating gap = tol + 1.0
    # The coordinate-update criterion uses the unscaled tolerance.
    cdef floating d_w_tol = tol
    cdef floating dual_norm_XtA
    cdef floating R_norm2
    cdef floating w_norm2
    cdef floating l1_norm
    cdef floating const
    cdef floating A_norm2
    cdef unsigned int ii
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

    if alpha == 0 and beta == 0:
        warnings.warn("Coordinate descent with no regularization may lead to "
                      "unexpected results and is discouraged.")

    with nogil:
        # R = y - np.dot(X, w)
        _copy(n_samples, &y[0], 1, &R[0], 1)
        _gemv(ColMajor, NoTrans, n_samples, n_features, -1.0, &X[0, 0],
              n_samples, &w[0], 1, 1.0, &R[0], 1)

        # tol *= np.dot(y, y)
        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)

        for n_iter in range(max_iter):
            w_max = 0.0
            d_w_max = 0.0
            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                # An all-zero column never moves its coefficient.
                if norm_cols_X[ii] == 0.0:
                    continue

                w_ii = w[ii]  # Store previous value

                if w_ii != 0.0:
                    # R += w_ii * X[:,ii]
                    _axpy(n_samples, w_ii, &X[0, ii], 1, &R[0], 1)

                # tmp = (X[:,ii]*R).sum()
                tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1)

                # Soft-thresholding step, scaled by the column norm plus
                # the L2 penalty.
                if positive and tmp < 0:
                    w[ii] = 0.0
                else:
                    w[ii] = (fsign(tmp) * fmax(fabs(tmp) - alpha, 0)
                             / (norm_cols_X[ii] + beta))

                if w[ii] != 0.0:
                    # R -= w[ii] * X[:,ii] # Update residual
                    _axpy(n_samples, -w[ii], &X[0, ii], 1, &R[0], 1)

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
                d_w_max = fmax(d_w_max, d_w_ii)

                w_max = fmax(w_max, fabs(w[ii]))

            if (w_max == 0.0 or
                    d_w_max / w_max < d_w_tol or
                    n_iter == max_iter - 1):
                # the biggest coordinate update of this iteration was smaller
                # than the tolerance: check the duality gap as ultimate
                # stopping criterion

                # XtA = np.dot(X.T, R) - beta * w
                _copy(n_features, &w[0], 1, &XtA[0], 1)
                _gemv(ColMajor, Trans,
                      n_samples, n_features, 1.0, &X[0, 0], n_samples,
                      &R[0], 1,
                      -beta, &XtA[0], 1)

                if positive:
                    dual_norm_XtA = max(n_features, &XtA[0])
                else:
                    dual_norm_XtA = abs_max(n_features, &XtA[0])

                # R_norm2 = np.dot(R, R)
                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)

                # w_norm2 = np.dot(w, w)
                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)

                # Rescale the dual point when it lies outside the feasible set.
                if (dual_norm_XtA > alpha):
                    const = alpha / dual_norm_XtA
                    A_norm2 = R_norm2 * (const ** 2)
                    gap = 0.5 * (R_norm2 + A_norm2)
                else:
                    const = 1.0
                    gap = R_norm2

                l1_norm = _asum(n_features, &w[0], 1)

                # np.dot(R.T, y)
                gap += (alpha * l1_norm
                        - const * _dot(n_samples, &R[0], 1, &y[0], 1)
                        + 0.5 * beta * (1 + const ** 2) * (w_norm2))

                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                message = (
                    "Objective did not converge. You might want to increase "
                    "the number of iterations, check the scale of the "
                    "features or consider increasing regularisation. "
                    f"Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
                )
                if alpha < np.finfo(np.float64).eps:
                    message += (
                        " Linear regression models with null weight for the "
                        "l1 regularization term are more efficiently fitted "
                        "using one of the solvers implemented in "
                        "sklearn.linear_model.Ridge/RidgeCV instead."
                    )
                warnings.warn(message, ConvergenceWarning)

    return w, gap, tol, n_iter + 1


def sparse_enet_coordinate_descent(floating [::1] w,
                            floating alpha, floating beta,
                            np.ndarray[floating, ndim=1, mode='c'] X_data,
                            np.ndarray[int, ndim=1, mode='c'] X_indices,
                            np.ndarray[int, ndim=1, mode='c'] X_indptr,
                            np.ndarray[floating, ndim=1] y,
                            floating[:] X_mean, int max_iter,
                            floating tol, object rng, bint random=0,
                            bint positive=0):
    """Cython version of the coordinate descent algorithm for Elastic-Net
    with a sparse design matrix.

    We minimize:

        (1/2) * norm(y - X w, 2)^2 + alpha norm(w, 1) + (beta/2) * norm(w, 2)^2

    where ``X`` is the implicitly centered matrix ``X_sparse - X_mean``: the
    column offsets are folded into the residual updates, so ``X_sparse`` is
    never densified.

    Parameters
    ----------
    w : contiguous array of shape (n_features,)
        Initial coefficients; updated in place.
    alpha, beta : floating
        L1 and L2 penalty strengths.
    X_data, X_indices, X_indptr : contiguous arrays
        Compressed-column representation of ``X_sparse``: the non-zeros of
        column ``ii`` are ``X_data[X_indptr[ii]:X_indptr[ii + 1]]`` and their
        row indices are the matching entries of ``X_indices``.
    y : array of shape (n_samples,)
        Target values. Strided arrays are accepted; a C-contiguous version is
        used wherever ``y`` is handed to BLAS.
    X_mean : array of shape (n_features,)
        Per-column offsets subtracted from ``X_sparse``. If all entries are
        zero, the centering corrections are skipped.
    max_iter, tol, rng, random, positive
        Same meaning as in ``enet_coordinate_descent``.

    Returns
    -------
    w, gap, tol, n_iter
        Same meaning as in ``enet_coordinate_descent``.
    """

    # get the data information into easy vars
    cdef unsigned int n_samples = y.shape[0]
    cdef unsigned int n_features = w.shape[0]

    # The BLAS calls below walk y with a unit increment, but the declared type
    # of ``y`` admits strided arrays. Use a C-contiguous version of it
    # (np.ascontiguousarray returns y itself when it is already contiguous).
    cdef np.ndarray[floating, ndim=1, mode='c'] y_c = np.ascontiguousarray(y)

    # compute norms of the columns of X
    cdef unsigned int ii
    cdef floating[:] norm_cols_X

    cdef unsigned int startptr = X_indptr[0]
    cdef unsigned int endptr

    # initial value of the residuals
    cdef floating[:] R = y.copy()

    cdef floating[:] X_T_R
    cdef floating[:] XtA

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    norm_cols_X = np.zeros(n_features, dtype=dtype)
    X_T_R = np.zeros(n_features, dtype=dtype)
    XtA = np.zeros(n_features, dtype=dtype)

    cdef floating tmp
    cdef floating w_ii
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating X_mean_ii
    cdef floating R_sum = 0.0
    cdef floating R_norm2
    cdef floating w_norm2
    cdef floating A_norm2
    cdef floating l1_norm
    cdef floating normalize_sum
    cdef floating gap = tol + 1.0
    # The coordinate-update criterion uses the unscaled tolerance.
    cdef floating d_w_tol = tol
    cdef floating dual_norm_XtA
    cdef unsigned int jj
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed
    cdef bint center = False

    with nogil:
        # center = (X_mean != 0).any()
        for ii in range(n_features):
            if X_mean[ii]:
                center = True
                break

        # Compute the squared norms of the centered columns and initialise
        # R = y - (X_sparse - X_mean) w in a single sweep over the columns.
        for ii in range(n_features):
            X_mean_ii = X_mean[ii]
            endptr = X_indptr[ii + 1]
            normalize_sum = 0.0
            w_ii = w[ii]

            for jj in range(startptr, endptr):
                normalize_sum += (X_data[jj] - X_mean_ii) ** 2
                R[X_indices[jj]] -= X_data[jj] * w_ii
            # The implicit zeros of the column contribute X_mean_ii ** 2 each.
            norm_cols_X[ii] = normalize_sum + \
                (n_samples - endptr + startptr) * X_mean_ii ** 2

            if center:
                for jj in range(n_samples):
                    R[jj] += X_mean_ii * w_ii
            startptr = endptr

        # tol *= np.dot(y, y)
        tol *= _dot(n_samples, &y_c[0], 1, &y_c[0], 1)

        for n_iter in range(max_iter):

            w_max = 0.0
            d_w_max = 0.0

            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                # An all-zero (centered) column never moves its coefficient.
                if norm_cols_X[ii] == 0.0:
                    continue

                startptr = X_indptr[ii]
                endptr = X_indptr[ii + 1]
                w_ii = w[ii]  # Store previous value
                X_mean_ii = X_mean[ii]

                if w_ii != 0.0:
                    # R += w_ii * X[:,ii]
                    for jj in range(startptr, endptr):
                        R[X_indices[jj]] += X_data[jj] * w_ii
                    if center:
                        for jj in range(n_samples):
                            R[jj] -= X_mean_ii * w_ii

                # tmp = (X[:,ii] * R).sum()
                tmp = 0.0
                for jj in range(startptr, endptr):
                    tmp += R[X_indices[jj]] * X_data[jj]

                if center:
                    R_sum = 0.0
                    for jj in range(n_samples):
                        R_sum += R[jj]
                    tmp -= R_sum * X_mean_ii

                # Soft-thresholding step, scaled by the column norm plus
                # the L2 penalty.
                if positive and tmp < 0.0:
                    w[ii] = 0.0
                else:
                    w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
                            / (norm_cols_X[ii] + beta)

                if w[ii] != 0.0:
                    # R -= w[ii] * X[:,ii] # Update residual
                    for jj in range(startptr, endptr):
                        R[X_indices[jj]] -= X_data[jj] * w[ii]

                    if center:
                        for jj in range(n_samples):
                            R[jj] += X_mean_ii * w[ii]

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
                if d_w_ii > d_w_max:
                    d_w_max = d_w_ii

                if fabs(w[ii]) > w_max:
                    w_max = fabs(w[ii])

            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                # the biggest coordinate update of this iteration was smaller than
                # the tolerance: check the duality gap as ultimate stopping
                # criterion

                # sparse X.T / dense R dot product
                if center:
                    R_sum = 0.0
                    for jj in range(n_samples):
                        R_sum += R[jj]

                for ii in range(n_features):
                    X_T_R[ii] = 0.0
                    for jj in range(X_indptr[ii], X_indptr[ii + 1]):
                        X_T_R[ii] += X_data[jj] * R[X_indices[jj]]

                    if center:
                        X_T_R[ii] -= X_mean[ii] * R_sum
                    XtA[ii] = X_T_R[ii] - beta * w[ii]

                if positive:
                    dual_norm_XtA = max(n_features, &XtA[0])
                else:
                    dual_norm_XtA = abs_max(n_features, &XtA[0])

                # R_norm2 = np.dot(R, R)
                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)

                # w_norm2 = np.dot(w, w)
                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
                # Rescale the dual point when it lies outside the feasible set.
                if (dual_norm_XtA > alpha):
                    const = alpha / dual_norm_XtA
                    A_norm2 = R_norm2 * const**2
                    gap = 0.5 * (R_norm2 + A_norm2)
                else:
                    const = 1.0
                    gap = R_norm2

                l1_norm = _asum(n_features, &w[0], 1)

                # np.dot(R.T, y)
                gap += (alpha * l1_norm
                        - const * _dot(n_samples, &R[0], 1, &y_c[0], 1)
                        + 0.5 * beta * (1 + const ** 2) * w_norm2)

                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                warnings.warn("Objective did not converge. You might want to "
                              "increase the number of iterations. Duality "
                              "gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return w, gap, tol, n_iter + 1


def enet_coordinate_descent_gram(floating[::1] w,
                                 floating alpha, floating beta,
                                 np.ndarray[floating, ndim=2, mode='c'] Q,
                                 np.ndarray[floating, ndim=1, mode='c'] q,
                                 np.ndarray[floating, ndim=1] y,
                                 int max_iter, floating tol, object rng,
                                 bint random=0, bint positive=0):
    """Cython version of the coordinate descent algorithm
    for Elastic-Net regression using a precomputed Gram matrix.

    We minimize

        (1/2) * w^T Q w - q^T w + alpha norm(w, 1) + (beta/2) * norm(w, 2)^2

    which amounts to the Elastic-Net problem when:
        Q = X^T X (Gram matrix)
        q = X^T y

    Parameters
    ----------
    w : contiguous array of shape (n_features,)
        Initial coefficients; updated in place.
    alpha, beta : floating
        L1 and L2 penalty strengths.
    Q : C-contiguous array of shape (n_features, n_features)
        Row ``ii`` of ``Q`` is read with a unit stride when updating ``Q w``.
    q : contiguous array of shape (n_features,)
    y : array of shape (n_samples,)
        Only used through ``np.dot(y, y)``: to scale ``tol`` and to recover
        the residual norm in the duality gap.
    max_iter, tol, rng, random, positive
        Same meaning as in ``enet_coordinate_descent``.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The updated coefficients.
    gap, tol, n_iter
        Same meaning as in ``enet_coordinate_descent``.
    """

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    # get the data information into easy vars
    cdef unsigned int n_features = Q.shape[0]

    # initial value "Q w" which will be kept up to date in the iterations
    cdef floating[:] H = np.dot(Q, w)

    cdef floating[:] XtA = np.zeros(n_features, dtype=dtype)
    cdef floating tmp
    cdef floating w_ii
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating q_dot_w
    cdef floating w_norm2
    cdef floating gap = tol + 1.0
    # The coordinate-update criterion uses the unscaled tolerance.
    cdef floating d_w_tol = tol
    cdef floating dual_norm_XtA
    cdef unsigned int ii
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

    cdef floating y_norm2 = np.dot(y, y)
    cdef floating* w_ptr = <floating*>&w[0]
    cdef floating* Q_ptr = &Q[0, 0]
    cdef floating* q_ptr = <floating*>q.data
    cdef floating* H_ptr = &H[0]
    cdef floating* XtA_ptr = &XtA[0]
    tol = tol * y_norm2

    if alpha == 0:
        warnings.warn("Coordinate descent with alpha=0 may lead to unexpected"
            " results and is discouraged.")

    with nogil:
        for n_iter in range(max_iter):
            w_max = 0.0
            d_w_max = 0.0
            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                if Q[ii, ii] == 0.0:
                    continue

                w_ii = w[ii]  # Store previous value

                if w_ii != 0.0:
                    # H -= w_ii * Q[ii]
                    _axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
                          H_ptr, 1)

                tmp = q[ii] - H[ii]

                # Soft-thresholding step, scaled by the diagonal entry plus
                # the L2 penalty.
                if positive and tmp < 0:
                    w[ii] = 0.0
                else:
                    w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
                        / (Q[ii, ii] + beta)

                if w[ii] != 0.0:
                    # H += w[ii] * Q[ii] # Update H = X.T X w
                    _axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
                          H_ptr, 1)

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
                if d_w_ii > d_w_max:
                    d_w_max = d_w_ii

                if fabs(w[ii]) > w_max:
                    w_max = fabs(w[ii])

            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                # the biggest coordinate update of this iteration was smaller than
                # the tolerance: check the duality gap as ultimate stopping
                # criterion

                # q_dot_w = np.dot(w, q)
                q_dot_w = _dot(n_features, w_ptr, 1, q_ptr, 1)

                for ii in range(n_features):
                    XtA[ii] = q[ii] - H[ii] - beta * w[ii]
                if positive:
                    dual_norm_XtA = max(n_features, XtA_ptr)
                else:
                    dual_norm_XtA = abs_max(n_features, XtA_ptr)

                # temp = np.sum(w * H)
                tmp = 0.0
                for ii in range(n_features):
                    tmp += w[ii] * H[ii]
                # ||y - X w||^2 = y^T y + w^T Q w - 2 q^T w
                R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w

                # w_norm2 = np.dot(w, w)
                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)

                # Rescale the dual point when it lies outside the feasible set.
                if (dual_norm_XtA > alpha):
                    const = alpha / dual_norm_XtA
                    A_norm2 = R_norm2 * (const ** 2)
                    gap = 0.5 * (R_norm2 + A_norm2)
                else:
                    const = 1.0
                    gap = R_norm2

                # The call to asum is equivalent to the L1 norm of w
                gap += (alpha * _asum(n_features, &w[0], 1) -
                        const * y_norm2 + const * q_dot_w +
                        0.5 * beta * (1 + const ** 2) * w_norm2)

                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                warnings.warn("Objective did not converge. You might want to "
                              "increase the number of iterations. Duality "
                              "gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return np.asarray(w), gap, tol, n_iter + 1


def enet_coordinate_descent_multi_task(
        floating[::1, :] W, floating l1_reg, floating l2_reg,
        np.ndarray[floating, ndim=2, mode='fortran'] X,  # TODO: use views with Cython 3.0
        np.ndarray[floating, ndim=2, mode='fortran'] Y,  # hopefully with skl 1.0
        int max_iter, floating tol, object rng, bint random=0):
    """Cython version of the coordinate descent algorithm
    for Elastic-Net multi-task regression.

    We minimize

        0.5 * norm(Y - X W.T, 2)^2 + l1_reg ||W.T||_21 + 0.5 * l2_reg norm(W.T, 2)^2

    Parameters
    ----------
    W : Fortran-ordered array of shape (n_tasks, n_features)
        Initial coefficients; updated in place, one column ``W[:, ii]`` per
        feature.
    l1_reg : floating
        Strength of the l2,1 penalty, i.e. the sum over features of the
        Euclidean norm of ``W[:, ii]``.
    l2_reg : floating
        Strength of the squared Frobenius penalty.
    X : Fortran-ordered array of shape (n_samples, n_features)
        Design matrix.
    Y : Fortran-ordered array of shape (n_samples, n_tasks)
        Targets.
    max_iter, rng, random
        Same meaning as in ``enet_coordinate_descent``.
    tol : floating
        As in ``enet_coordinate_descent``, but the duality gap is compared
        with ``tol`` times the squared Frobenius norm of ``Y``.

    Returns
    -------
    W : ndarray of shape (n_tasks, n_features)
        The updated coefficients.
    gap, tol, n_iter
        Same meaning as in ``enet_coordinate_descent``.
    """

    if floating is float:
        dtype = np.float32
    else:
        dtype = np.float64

    # get the data information into easy vars
    cdef unsigned int n_samples = X.shape[0]
    cdef unsigned int n_features = X.shape[1]
    cdef unsigned int n_tasks = Y.shape[1]

    # to store XtA
    cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype)
    cdef floating XtA_axis1norm
    cdef floating dual_norm_XtA

    # initial value of the residuals
    cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F')

    cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
    cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype)
    cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype)
    cdef floating d_w_max
    cdef floating w_max
    cdef floating d_w_ii
    cdef floating nn
    cdef floating scale
    cdef floating W_ii_abs_max
    cdef floating gap = tol + 1.0
    # The coordinate-update criterion uses the unscaled tolerance.
    cdef floating d_w_tol = tol
    cdef floating R_norm
    cdef floating w_norm
    cdef floating ry_sum
    cdef floating l21_norm
    cdef unsigned int ii
    cdef unsigned int jj
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

    # Column ii of the Fortran-ordered X starts at X_ptr + ii * n_samples.
    cdef floating* X_ptr = &X[0, 0]
    cdef floating* Y_ptr = &Y[0, 0]

    if l1_reg == 0:
        warnings.warn("Coordinate descent with l1_reg=0 may lead to unexpected"
            " results and is discouraged.")

    with nogil:
        # norm_cols_X = (np.asarray(X) ** 2).sum(axis=0)
        for ii in range(n_features):
            norm_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2

        # R = Y - np.dot(X, W.T)
        _copy(n_samples * n_tasks, Y_ptr, 1, &R[0, 0], 1)
        for ii in range(n_features):
            for jj in range(n_tasks):
                if W[jj, ii] != 0:
                    _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
                          &R[0, jj], 1)

        # tol = tol * linalg.norm(Y, ord='fro') ** 2
        tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2

        for n_iter in range(max_iter):
            w_max = 0.0
            d_w_max = 0.0
            for f_iter in range(n_features):  # Loop over coordinates
                if random:
                    ii = rand_int(n_features, rand_r_state)
                else:
                    ii = f_iter

                if norm_cols_X[ii] == 0.0:
                    continue

                # w_ii = W[:, ii] # Store previous value
                _copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1)

                # Using Numpy:
                # R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
                # Using Blas Level2:
                # _ger(RowMajor, n_samples, n_tasks, 1.0,
                #      &X[0, ii], 1,
                #      &w_ii[0], 1, &R[0, 0], n_tasks)
                # Using Blas Level1 and for loop to avoid slower threads
                # for such small vectors
                for jj in range(n_tasks):
                    if w_ii[jj] != 0:
                        _axpy(n_samples, w_ii[jj], X_ptr + ii * n_samples, 1,
                              &R[0, jj], 1)

                # Using numpy:
                # tmp = np.dot(X[:, ii][None, :], R).ravel()
                # Using BLAS Level 2:
                # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
                #       n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
                # Using BLAS Level 1 (faster for small vectors like here):
                for jj in range(n_tasks):
                    tmp[jj] = _dot(n_samples, X_ptr + ii * n_samples, 1,
                                   &R[0, jj], 1)

                # nn = sqrt(np.sum(tmp ** 2))
                nn = _nrm2(n_tasks, &tmp[0], 1)

                # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
                # When nn is not positive, tmp is identically zero (or holds
                # NaNs), and a zero scale reproduces what the unguarded formula
                # gives under C float semantics without evaluating l1_reg / 0.
                if nn > 0.0:
                    scale = fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
                else:
                    scale = 0.0
                _copy(n_tasks, &tmp[0], 1, &W[0, ii], 1)
                _scal(n_tasks, scale, &W[0, ii], 1)

                # Using numpy:
                # R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
                # Using BLAS Level 2:
                # Update residual : rank 1 update
                # _ger(RowMajor, n_samples, n_tasks, -1.0,
                #      &X[0, ii], 1, &W[0, ii], 1,
                #      &R[0, 0], n_tasks)
                # Using BLAS Level 1 (faster for small vectors like here):
                for jj in range(n_tasks):
                    if W[jj, ii] != 0:
                        _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
                              &R[0, jj], 1)

                # update the maximum absolute coefficient update
                d_w_ii = diff_abs_max(n_tasks, &W[0, ii], &w_ii[0])

                if d_w_ii > d_w_max:
                    d_w_max = d_w_ii

                W_ii_abs_max = abs_max(n_tasks, &W[0, ii])
                if W_ii_abs_max > w_max:
                    w_max = W_ii_abs_max

            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                # the biggest coordinate update of this iteration was smaller than
                # the tolerance: check the duality gap as ultimate stopping
                # criterion

                # XtA = np.dot(X.T, R) - l2_reg * W.T
                for ii in range(n_features):
                    for jj in range(n_tasks):
                        XtA[ii, jj] = _dot(
                            n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
                            ) - l2_reg * W[jj, ii]

                # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
                dual_norm_XtA = 0.0
                for ii in range(n_features):
                    # np.sqrt(np.sum(XtA ** 2, axis=1))
                    XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1)
                    if XtA_axis1norm > dual_norm_XtA:
                        dual_norm_XtA = XtA_axis1norm

                # TODO: use squared L2 norm directly
                # R_norm = linalg.norm(R, ord='fro')
                # w_norm = linalg.norm(W, ord='fro')
                R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
                w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
                # Rescale the dual point when it lies outside the feasible set.
                if (dual_norm_XtA > l1_reg):
                    const = l1_reg / dual_norm_XtA
                    A_norm = R_norm * const
                    gap = 0.5 * (R_norm ** 2 + A_norm ** 2)
                else:
                    const = 1.0
                    gap = R_norm ** 2

                # ry_sum = np.sum(R * y)
                ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)

                # l21_norm = np.sqrt(np.sum(W ** 2, axis=0)).sum()
                l21_norm = 0.0
                for ii in range(n_features):
                    l21_norm += _nrm2(n_tasks, &W[0, ii], 1)

                gap += l1_reg * l21_norm - const * ry_sum + \
                    0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)

                if gap < tol:
                    # return if we reached desired tolerance
                    break
        else:
            # for/else, runs if for doesn't end with a `break`
            with gil:
                warnings.warn("Objective did not converge. You might want to "
                              "increase the number of iterations. Duality "
                              "gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return np.asarray(W), gap, tol, n_iter + 1