1 /*
2  * irlb: Implicitly restarted Lanczos bidiagonalization partial SVD.
3  * Copyright (c) 2016 by Bryan W. Lewis
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9 
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14 
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18 
19 #include <stdlib.h>
20 #include <string.h>
21 #include <fcntl.h>
22 #include <assert.h>
23 #include <math.h>
24 
25 #define USE_FC_LEN_T
26 #include <Rconfig.h>
27 #include "R_ext/BLAS.h"
28 #ifndef FCONE
29 # define FCONE
30 #endif
31 #include <R.h>
32 #define USE_RINTERNALS
33 #include <Rinternals.h>
34 #include <Rdefines.h>
35 
36 #include "R_ext/Lapack.h"
37 #include "R_ext/Rdynload.h"
38 #include "R_ext/Utils.h"
39 #include "R_ext/Parse.h"
40 
41 #include "Matrix.h"
42 #include "Matrix_stubs.c"
43 
44 #include "irlb.h"
45 
46 /* helper function for calling rnorm below */
47 SEXP
RNORM(int n)48 RNORM (int n)
49 {
50   char buf[4096];
51   SEXP cmdSexp, cmdexpr, ans = R_NilValue;
52   ParseStatus status;
53   cmdSexp = PROTECT (allocVector (STRSXP, 1));
54   snprintf (buf, 4095, "rnorm(%d)", n);
55   SET_STRING_ELT (cmdSexp, 0, mkChar (buf));
56   cmdexpr = PROTECT (R_ParseVector (cmdSexp, -1, &status, R_NilValue));
57   if (status != PARSE_OK)
58     {
59       UNPROTECT (2);
60       error ("invalid call");
61     }
62   for (int i = 0; i < length (cmdexpr); i++)
63     {
64       ans = PROTECT (eval (VECTOR_ELT (cmdexpr, i), R_GlobalEnv));
65       UNPROTECT (1);
66     }
67   UNPROTECT (2);
68   return ans;
69 }
70 
71 /* irlb C implementation wrapper for R
72  *
73  * X double precision input matrix
74  * NU integer number of singular values/vectors to compute must be > 3
75  * INIT double precision starting vector length(INIT) must equal ncol(X)
76  * WORK integer working subspace dimension must be > NU
77  * MAXIT integer maximum number of iterations
78  * TOL double tolerance
79  * EPS double invariant subspace detection tolerance
80  * MULT integer 0 X is a dense matrix (dgemm), 1 sparse (cholmod)
81  * RESTART integer 0 no or > 0 indicates restart of dimension n
82  * RV, RW, RS optional restart V W and S values of dimension RESTART
83  *    (only used when RESTART > 0)
84  * SCALE either NULL (no scaling) or a vector of length ncol(X)
85  * SHIFT either NULL (no shift) or a single double-precision number
86  * CENTER either NULL (no centering) or a vector of length ncol(X)
87  * SVTOL double tolerance max allowed per cent change in each estimated singular value
88  *
89  * Returns a list with 6 elements:
90  * 1. vector of estimated singular values
91  * 2. matrix of estimated left singular vectors
92  * 3. matrix of estimated right singular vectors
93  * 4. number of algorithm iterations
94  * 5. number of matrix vector products
95  * 6. irlb C algorithm return error code (see irlb below)
96  */
97 SEXP
IRLB(SEXP X,SEXP NU,SEXP INIT,SEXP WORK,SEXP MAXIT,SEXP TOL,SEXP EPS,SEXP MULT,SEXP RESTART,SEXP RV,SEXP RW,SEXP RS,SEXP SCALE,SEXP SHIFT,SEXP CENTER,SEXP SVTOL)98 IRLB (SEXP X, SEXP NU, SEXP INIT, SEXP WORK, SEXP MAXIT, SEXP TOL, SEXP EPS,
99       SEXP MULT, SEXP RESTART, SEXP RV, SEXP RW, SEXP RS, SEXP SCALE,
100       SEXP SHIFT, SEXP CENTER, SEXP SVTOL)
101 {
102   SEXP ANS, S, U, V;
103   double *V1, *U1, *W, *F, *B, *BU, *BV, *BS, *BW, *res, *T, *scale, *shift,
104     *center, *SVRATIO;
105   int i, iter, mprod, ret;
106   int m, n;
107 
108   int mult = INTEGER (MULT)[0];
109   void *AS = NULL;
110   double *A = NULL;
111   switch (mult)
112     {
113     case 1:
114       AS = (void *) AS_CHM_SP (X);
115       int *dims = INTEGER (GET_SLOT (X, install ("Dim")));
116       m = dims[0];
117       n = dims[1];
118       break;
119     default:
120       A = REAL (X);
121       m = nrows (X);
122       n = ncols (X);
123     }
124   int nu = INTEGER (NU)[0];
125   int work = INTEGER (WORK)[0];
126   int maxit = INTEGER (MAXIT)[0];
127   double tol = REAL (TOL)[0];
128   double svtol = REAL (SVTOL)[0];
129   int lwork = 7 * work * (1 + work);
130   int restart = INTEGER (RESTART)[0];
131   double eps = REAL (EPS)[0];
132 
133   PROTECT (ANS = NEW_LIST (6));
134   PROTECT (S = allocVector (REALSXP, nu));
135   PROTECT (U = allocVector (REALSXP, m * work));
136   PROTECT (V = allocVector (REALSXP, n * work));
137   if (restart == 0)
138     for (i = 0; i < n; ++i)
139       (REAL (V))[i] = (REAL (INIT))[i];
140 
141   /* set up intermediate working storage */
142   scale = NULL;
143   shift = NULL;
144   center = NULL;
145   if (TYPEOF (SCALE) == REALSXP)
146     {
147       scale = (double *) R_alloc (n * 2, sizeof (double));
148       memcpy (scale, REAL (SCALE), n * sizeof (double));
149     }
150   if (TYPEOF (SHIFT) == REALSXP)
151     {
152       shift = REAL (SHIFT);
153     }
154   if (TYPEOF (CENTER) == REALSXP)
155     {
156       center = REAL (CENTER);
157     }
158   SVRATIO = (double *) R_alloc (work, sizeof (double));
159   V1 = (double *) R_alloc (n * work, sizeof (double));
160   U1 = (double *) R_alloc (m * work, sizeof (double));
161   W = (double *) R_alloc (m * work, sizeof (double));
162   F = (double *) R_alloc (n, sizeof (double));
163   B = (double *) R_alloc (work * work, sizeof (double));
164   BU = (double *) R_alloc (work * work, sizeof (double));
165   BV = (double *) R_alloc (work * work, sizeof (double));
166   BS = (double *) R_alloc (work, sizeof (double));
167   BW = (double *) R_alloc (lwork, sizeof (double));
168   res = (double *) R_alloc (work, sizeof (double));
169   T = (double *) R_alloc (lwork, sizeof (double));
170   if (restart > 0)
171     {
172       memcpy (REAL (V), REAL (RV), n * (restart + 1) * sizeof (double));
173       memcpy (W, REAL (RW), m * restart * sizeof (double));
174       memset (B, 0, work * work * sizeof (double));
175       for (i = 0; i < restart; ++i)
176         B[i + work * i] = REAL (RS)[i];
177     }
178   ret =
179     irlb (A, AS, mult, m, n, nu, work, maxit, restart, tol, scale, shift, center,
180           REAL (S), REAL (U), REAL (V), &iter, &mprod, eps, lwork, V1, U1, W,
181           F, B, BU, BV, BS, BW, res, T, svtol, SVRATIO);
182   SET_VECTOR_ELT (ANS, 0, S);
183   SET_VECTOR_ELT (ANS, 1, U);
184   SET_VECTOR_ELT (ANS, 2, V);
185   SET_VECTOR_ELT (ANS, 3, ScalarInteger (iter));
186   SET_VECTOR_ELT (ANS, 4, ScalarInteger (mprod));
187   SET_VECTOR_ELT (ANS, 5, ScalarInteger (ret));
188   UNPROTECT (4);
189   return ANS;
190 }
191 
192 /* irlb: main computation function.
193  * returns:
194  *  0 on success,
195  * -1 invalid dimensions,
196  * -2 not converged
197  * -3 out of memory
198  * -4 starting vector near the null space of A
199  *
200  * all data must be allocated by caller, required sizes listed below
201  */
202 int
irlb(double * A,void * AS,int mult,int m,int n,int nu,int work,int maxit,int restart,double tol,double * scale,double * shift,double * center,double * s,double * U,double * V,int * ITER,int * MPROD,double eps,int lwork,double * V1,double * U1,double * W,double * F,double * B,double * BU,double * BV,double * BS,double * BW,double * res,double * T,double svtol,double * svratio)203 irlb (double *A,                // Input data matrix (double case)
204       void *AS,                 // input data matrix (sparse case)
205       int mult,                 // 0 -> use double *A, 1 -> use AS
206       int m,                    // data matrix number of rows, must be > 3.
207       int n,                    // data matrix number of columns, must be > 3.
208       int nu,                   // dimension of solution
209       int work,                 // working dimension, must be > 3.
210       int maxit,                // maximum number of main iterations
211       int restart,              // 0->no, n>0 -> restarted algorithm of dimension n
212       double tol,               // convergence tolerance
213       double *scale,            // optional scale (NULL for no scale) size n * 2
214       double *shift,            // optional shift (NULL for no shift)
215       double *center,           // optional center (NULL for no center)
216       // output values
217       double *s,                // output singular values at least length nu
218       double *U,                // output left singular vectors  m x work
219       double *V,                // output right singular vectors n x work
220       int *ITER,                // ouput number of Lanczos iterations
221       int *MPROD,               // output number of matrix vector products
222       double eps,               // tolerance for invariant subspace detection
223       // working intermediate storage, sizes shown
224       int lwork, double *V1,    // n x work
225       double *U1,               // m x work
226       double *W,                // m x work  input when restart > 0
227       double *F,                // n
228       double *B,                // work x work  input when restart > 0
229       double *BU,               // work x work
230       double *BV,               // work x work
231       double *BS,               // work
232       double *BW,               // lwork
233       double *res,              // work
234       double *T,                // lwork
235       double svtol,             // svtol limit
236       double *svratio)          // convtest extra storage vector of length work
237 {
238   double d, S, R, alpha, beta, R_F, SS;
239   double *x;
240   int jj, kk;
241   int converged;
242   int info, j, k = restart;
243   int inc = 1;
244   int mprod = 0;
245   int iter = 0;
246   double Smax = 0;
247   SEXP FOO;
248 
249 /* Check for valid input dimensions */
250   if (work < 4 || n < 4 || m < 4)
251     return -1;
252 
253   if (restart == 0)
254     memset (B, 0, work * work * sizeof (double));
255   memset(svratio, 0, work * sizeof(double));
256 
257 /* Main iteration */
258   while (iter < maxit)
259     {
260       j = 0;
261 /*  Normalize starting vector */
262       if (iter == 0 && restart == 0)
263         {
264           d = F77_CALL (dnrm2) (&n, V, &inc);
265           if (d < eps)
266             return -1;
267           d = 1 / d;
268           F77_CALL (dscal) (&n, &d, V, &inc);
269         }
270       else
271         j = k;
272 
273 /* optionally apply scale */
274       x = V + j * n;
275       if (scale)
276         {
277           x = scale + n;
278           memcpy (scale + n, V + j * n, n * sizeof (double));
279           for (kk = 0; kk < n; ++kk)
280             x[kk] = x[kk] / scale[kk];
281         }
282 
283       switch (mult)
284         {
285         case 1:
286           dsdmult ('n', m, n, AS, x, W + j * m);
287           break;
288         default:
289           alpha = 1;
290           beta = 0;
291           F77_CALL (dgemv) ("n", &m, &n, &alpha, (double *) A, &m, x,
292                             &inc, &beta, W + j * m, &inc FCONE);
293         }
294       mprod++;
295       R_CheckUserInterrupt ();
296 /* optionally apply shift in square cases m = n */
297       if (shift)
298         {
299           jj = j * m;
300           for (kk = 0; kk < m; ++kk)
301             W[jj + kk] = W[jj + kk] + shift[0] * x[kk];
302         }
303 /* optionally apply centering */
304       if (center)
305         {
306           jj = j * m;
307           beta = F77_CALL (ddot) (&n, x, &inc, center, &inc);
308           for (kk = 0; kk < m; ++kk)
309             W[jj + kk] = W[jj + kk] - beta;
310         }
311       if (iter > 0)
312         orthog (W, W + j * m, T, m, j, 1);
313       S = F77_CALL (dnrm2) (&m, W + j * m, &inc);
314       if (S < eps && j == 0)
315         return -4;
316       SS = 1.0 / S;
317       F77_CALL (dscal) (&m, &SS, W + j * m, &inc);
318 
319 /* The Lanczos process */
320       while (j < work)
321         {
322           switch (mult)
323             {
324             case 1:
325               dsdmult ('t', m, n, AS, W + j * m, F);
326               break;
327             default:
328               alpha = 1.0;
329               beta = 0.0;
330               F77_CALL (dgemv) ("t", &m, &n, &alpha, (double *) A, &m,
331                                 W + j * m, &inc, &beta, F, &inc FCONE);
332             }
333           mprod++;
334           R_CheckUserInterrupt ();
335 /* optionally apply shift, scale, center */
336           if (shift)
337             {
338               // Note, not a bug because shift only applies to square matrices
339               for (kk = 0; kk < m; ++kk)
340                 F[kk] = F[kk] + shift[0] * W[j * m + kk];
341             }
342           if (scale)
343             {
344               for (kk = 0; kk < n; ++kk)
345                 F[kk] = F[kk] / scale[kk];
346             }
347           if (center)
348             {
349               beta = 0;
350               for (kk = 0; kk < m; ++kk) beta += W[j *m + kk];
351               if (scale)
352                 for (kk = 0; kk < n; ++kk)
353                   F[kk] = F[kk] - beta * center[kk] / scale[kk];
354               else
355                 for (kk = 0; kk < n; ++kk)
356                   F[kk] = F[kk] - beta * center[kk];
357             }
358           SS = -S;
359           F77_CALL (daxpy) (&n, &SS, V + j * n, &inc, F, &inc);
360           orthog (V, F, T, n, j + 1, 1);
361 
362           if (j + 1 < work)
363             {
364               R_F = F77_CALL (dnrm2) (&n, F, &inc);
365               R = 1.0 / R_F;
366               if (R_F < eps)        // near invariant subspace
367                 {
368                   FOO = RNORM (n);
369                   for (kk = 0; kk < n; ++kk)
370                     F[kk] = REAL (FOO)[kk];
371                   orthog (V, F, T, n, j + 1, 1);
372                   R_F = F77_CALL (dnrm2) (&n, F, &inc);
373                   R = 1.0 / R_F;
374                   R_F = 0;
375                 }
376               memmove (V + (j + 1) * n, F, n * sizeof (double));
377               F77_CALL (dscal) (&n, &R, V + (j + 1) * n, &inc);
378               B[j * work + j] = S;
379               B[(j + 1) * work + j] = R_F;
380 /* optionally apply scale */
381               x = V + (j + 1) * n;
382               if (scale)
383                 {
384                   x = scale + n;
385                   memcpy (x, V + (j + 1) * n, n * sizeof (double));
386                   for (kk = 0; kk < n; ++kk)
387                     x[kk] = x[kk] / scale[kk];
388                 }
389               switch (mult)
390                 {
391                 case 1:
392                   dsdmult ('n', m, n, AS, x, W + (j + 1) * m);
393                   break;
394                 default:
395                   alpha = 1.0;
396                   beta = 0.0;
397                   F77_CALL (dgemv) ("n", &m, &n, &alpha, (double *) A, &m,
398                                     x, &inc, &beta, W + (j + 1) * m, &inc FCONE);
399                 }
400               mprod++;
401               R_CheckUserInterrupt ();
402 /* optionally apply shift */
403               if (shift)
404                 {
405                   jj = j + 1;
406                   for (kk = 0; kk < m; ++kk)
407                     W[jj * m + kk] = W[jj * m + kk] + shift[0] * x[kk];
408                 }
409 /* optionally apply centering */
410               if (center)
411                 {
412                   jj = (j + 1) * m;
413                   beta = F77_CALL (ddot) (&n, x, &inc, center, &inc);
414                   for (kk = 0; kk < m; ++kk)
415                     W[jj + kk] = W[jj + kk] - beta;
416                 }
417 /* One step of classical Gram-Schmidt */
418               R = -R_F;
419               F77_CALL (daxpy) (&m, &R, W + j * m, &inc, W + (j + 1) * m,
420                                 &inc);
421 /* full re-orthogonalization of W_{j+1} */
422               orthog (W, W + (j + 1) * m, T, m, j + 1, 1);
423               S = F77_CALL (dnrm2) (&m, W + (j + 1) * m, &inc);
424               SS = 1.0 / S;
425               if (S < eps)
426                 {
427                   FOO = RNORM (m);
428                   jj = (j + 1) * m;
429                   for (kk = 0; kk < m; ++kk)
430                     W[jj + kk] = REAL (FOO)[kk];
431                   orthog (W, W + (j + 1) * m, T, m, j + 1, 1);
432                   S = F77_CALL (dnrm2) (&m, W + (j + 1) * m, &inc);
433                   SS = 1.0 / S;
434                   F77_CALL (dscal) (&m, &SS, W + (j + 1) * m, &inc);
435                   S = 0;
436                 }
437               else
438                 F77_CALL (dscal) (&m, &SS, W + (j + 1) * m, &inc);
439             }
440           else
441             {
442               B[j * work + j] = S;
443             }
444           j++;
445         }
446 
447       memmove (BU, B, work * work * sizeof (double));   // Make a working copy of B
448       int *BI = (int *) T;
449       F77_CALL (dgesdd) ("O", &work, &work, BU, &work, BS, BU, &work, BV,
450                          &work, BW, &lwork, BI, &info FCONE);
451       R_F = F77_CALL (dnrm2) (&n, F, &inc);
452       R = 1.0 / R_F;
453       F77_CALL (dscal) (&n, &R, F, &inc);
454 /* Force termination after encountering linear dependence */
455       if (R_F < eps)
456         R_F = 0;
457 
458       Smax = 0;
459       for (jj = 0; jj < j; ++jj)
460         {
461           if (BS[jj] > Smax)
462             Smax = BS[jj];
463           svratio[jj] = fabs (svratio[jj] - BS[jj]) / BS[jj];
464         }
465       for (kk = 0; kk < j; ++kk)
466         res[kk] = R_F * BU[kk * work + (j - 1)];
467 /* Update k to be the number of converged singular values. */
468       convtests (j, nu, tol, svtol, Smax, svratio, res, &k, &converged, S);
469 
470       if (converged == 1)
471         {
472           iter++;
473           break;
474         }
475       for (jj = 0; jj < j; ++jj)
476         svratio[jj] = BS[jj];
477 
478       alpha = 1;
479       beta = 0;
480       F77_CALL (dgemm) ("n", "t", &n, &k, &j, &alpha, V, &n, BV, &work, &beta,
481                         V1, &n FCONE FCONE);
482       memmove (V, V1, n * k * sizeof (double));
483       memmove (V + n * k, F, n * sizeof (double));
484 
485       memset (B, 0, work * work * sizeof (double));
486       for (jj = 0; jj < k; ++jj)
487         {
488           B[jj * work + jj] = BS[jj];
489           B[k * work + jj] = res[jj];
490         }
491 
492 /*   Update the left approximate singular vectors */
493       alpha = 1;
494       beta = 0;
495       F77_CALL (dgemm) ("n", "n", &m, &k, &j, &alpha, W, &m, BU, &work, &beta,
496                         U1, &m FCONE FCONE);
497       memmove (W, U1, m * k * sizeof (double));
498       iter++;
499     }
500 
501 /* Results */
502   memmove (s, BS, nu * sizeof (double));        /* Singular values */
503   alpha = 1;
504   beta = 0;
505   F77_CALL (dgemm) ("n", "n", &m, &nu, &work, &alpha, W, &m, BU, &work, &beta,
506                     U, &m FCONE FCONE);
507   F77_CALL (dgemm) ("n", "t", &n, &nu, &work, &alpha, V, &n, BV, &work, &beta,
508                     V1, &n FCONE FCONE);
509   memmove (V, V1, n * nu * sizeof (double));
510 
511   *ITER = iter;
512   *MPROD = mprod;
513   return (converged == 1 ? 0 : -2);
514 }
515 
516 
517 cholmod_common chol_c;
518 /* Need our own CHOLMOD error handler */
519 void attribute_hidden
irlba_R_cholmod_error(int status,const char * file,int line,const char * message)520 irlba_R_cholmod_error (int status, const char *file, int line,
521                        const char *message)
522 {
523   if (status < 0)
524     error ("Cholmod error '%s' at file:%s, line %d", message, file, line);
525   else
526     warning ("Cholmod warning '%s' at file:%s, line %d", message, file, line);
527 }
528 
529 static const R_CallMethodDef CallEntries[] = {
530   {"IRLB", (DL_FUNC) & IRLB, 16},
531   {NULL, NULL, 0}
532 };
533 
534 #ifdef HAVE_VISIBILITY_ATTRIBUTE
535 __attribute__ ((visibility ("default")))
536 #endif
537 void
R_init_irlba(DllInfo * dll)538 R_init_irlba (DllInfo * dll)
539 {
540 
541   R_RegisterCCallable("irlba", "orthog",
542                       (DL_FUNC) &orthog);
543   R_RegisterCCallable("irlba", "irlb",
544                       (DL_FUNC) &irlb);
545 
546 
547   R_registerRoutines (dll, NULL, CallEntries, NULL, NULL);
548   R_useDynamicSymbols (dll, 0);
549   M_R_cholmod_start (&chol_c);
550   chol_c.final_ll = 1;          /* LL' form of simplicial factorization */
551   /* need own error handler, that resets  final_ll (after *_defaults()) : */
552   chol_c.error_handler = irlba_R_cholmod_error;
553 }
554 
555 void
R_unload_irlba(DllInfo * dll)556 R_unload_irlba (DllInfo * dll)
557 {
558   M_cholmod_finish (&chol_c);
559 }
560 
561 
562 void
dsdmult(char transpose,int m,int n,void * a,double * b,double * c)563 dsdmult (char transpose, int m, int n, void * a, double *b, double *c)
564 {
565   DL_FUNC sdmult = R_GetCCallable ("Matrix", "cholmod_sdmult");
566   int t = transpose == 't' ? 1 : 0;
567   CHM_SP cha = (CHM_SP) a;
568 
569   cholmod_dense chb;
570   chb.nrow = transpose == 't' ? m : n;
571   chb.d = chb.nrow;
572   chb.ncol = 1;
573   chb.nzmax = chb.nrow;
574   chb.xtype = cha->xtype;
575   chb.dtype = 0;
576   chb.x = (void *) b;
577   chb.z = (void *) NULL;
578 
579   cholmod_dense chc;
580   chc.nrow = transpose == 't' ? n : m;
581   chc.d = chc.nrow;
582   chc.ncol = 1;
583   chc.nzmax = chc.nrow;
584   chc.xtype = cha->xtype;
585   chc.dtype = 0;
586   chc.x = (void *) c;
587   chc.z = (void *) NULL;
588 
589   double one[] = { 1, 0 }, zero[] = { 0, 0};
590   sdmult (cha, t, one, zero, &chb, &chc, &chol_c);
591 }
592