1 #include "mrilib.h"
2 
3    /*-------------------------------------------------------------------*/
4    /*** LASSO functions for AFNI programs.                            ***/
5    /*** We start with some utilities for setting LASSO parameters.    ***/
6    /*** Then some internal (static) functions for common necessities. ***/
7    /*-------------------------------------------------------------------*/
8 
/* Verbosity flag: recomputed on each THD_lasso_L2fit() call and, when
   nonzero, makes that call print a one-line summary after iterating. */

static int lasso_verb = 0 ;
10 
11 /*----------------------------------------------------------------------------*/
12 
13 /** set the fixed value of lambda (flam);
14     note that flam will always be positive (never 0) **/
15 
16 /* this is set once-and-for-all so doesn't need to be thread-ized */
17 
/* Fallback L1 penalty, used when no usable lam[] vector is available. */

static float flam = 0.666f ;

/* Replace the fallback penalty; any input that is not strictly
   positive is ignored, so flam can never become zero or negative. */

void THD_lasso_fixlam( float x )
{
   if( !(x > 0.0f) ) return ;
   flam = x ;
}
21 
22 /*............................................................................*/
23 
24 /** set the convergence parameter (deps) [function not used at this time] **/
25 
/* Convergence threshold tested by the outer iteration loops. */

static float deps = 0.0000321111f ;

/* Accept x only if it lies in [0.0000003,0.1];
   otherwise fall back to a built-in value. */

void THD_lasso_setdeps( float x )
{
   if( x >= 0.0000003f && x <= 0.1f ) deps = x ;
   else                               deps = 0.0000654321f ;
}
31 
32 /*............................................................................*/
33 
34 /** set this to 1 to do 'post-LASSO' re-regression [not used at this time] **/
35 
/* Nonzero => after the penalized fit, re-regress using only the
   references that ended up with nonzero coefficients. */

static int do_post = 0 ;

void THD_lasso_dopost( int x )
{
   do_post = x ;
}
39 
40 /*............................................................................*/
41 
42 /** set this to 1 to scale LASSO lambda by estimated sigma **/
43 
44 /* this is set once-and-for-all so doesn't need to be thread-ized */
45 
/* Nonzero => THD_lasso_L2fit() multiplies each penalty by the
   noise level returned from estimate_sigma(). */

static int do_sigest = 0 ;

void THD_lasso_dosigest( int x )
{
   do_sigest = x ;
}
49 
50 /*............................................................................*/
51 
52 /** set the entire lambda vector **/
53 
54 static floatvec *vlam = NULL ;
55 
56 /* this is set once-and-for-all so doesn't need to be thread-ized */
57 
THD_lasso_setlamvec(int nref,float * lam)58 void THD_lasso_setlamvec( int nref , float *lam )
59 {
60    register int ii ;
61 ENTRY("THD_lasso_setlamvec") ;
62 #pragma omp critical (MALLOC)
63    { KILL_floatvec(vlam) ; }
64    if( nref > 0 && lam != NULL ){
65 #pragma omp critical (MALLOC)
66      { MAKE_floatvec(vlam,nref) ; }
67      for( ii=0 ; ii < nref ; ii++ ) vlam->ar[ii] = lam[ii] ;
68    }
69    EXRETURN ;
70 }
71 
72 /*............................................................................*/
73 
74 /** set initial parameters estimates **/
75 
76 /* not used at this time, but is thread-ized for safety */
77 
78 AO_DEFINE_SCALAR(floatvec*,vpar) ;
79 #if 0
80 static floatvec *vpar = NULL ;
81 #endif
82 
THD_lasso_setparvec(int nref,float * par)83 void THD_lasso_setparvec( int nref , float *par )
84 {
85    register int ii ;
86 ENTRY("THD_lasso_setparvec") ;
87 #pragma omp critical (MALLOC)
88    { KILL_floatvec(AO_VALUE(vpar)) ; }
89    if( nref > 0 && par != NULL ){
90 #pragma omp critical (MALLOC)
91      { MAKE_floatvec(AO_VALUE(vpar),nref) ; }
92      for( ii=0 ; ii < nref ; ii++ ) AO_VALUE(vpar)->ar[ii] = par[ii] ;
93    }
94    EXRETURN ;
95 }
96 
97 /*----------------------------------------------------------------------------*/
98 /* Centro blocks = indexes over which the shrinkage is toward the centromean
99                    parameter (over the block) rather than toward 0.
100    Each centro block must have at least 3 entries.
101    NOTES: If the caller is an idiot, stupid things will happen; for example:
102             * If any of the entries of mb->ar[] are out of
103               the index range of the parameters (0..nref-1)
104             * If multiple centro blocks are used and share some indexes
105             * If an un-penalized index (mylam[i]==0) is provided
106           My suggestion is to avoid being an idiot.  [Aug 2021 - RWCox]
107 *//*--------------------------------------------------------------------------*/
108 
109 /* set once-and-for-all so doesn't need to be thread-ized */
110 
111 static int cenblok_num  = 0 ;
112 static intvec **cenblok = NULL ;
113 
THD_lasso_add_centro_block(intvec * mb)114 void THD_lasso_add_centro_block( intvec *mb )
115 {
116 ENTRY("THD_lasso_add_centro_block") ;
117 
118    if( mb == NULL ){  /* signal to clear all centro blocks */
119      int ii ;
120      if( cenblok != NULL ){
121        for( ii=0 ; ii < cenblok_num ; ii++ ){ KILL_intvec( cenblok[ii] ) ; }
122        free(cenblok) ;
123      }
124      EXRETURN ;
125    }
126 
127    if( mb->nar < 3 || mb->ar == NULL ) EXRETURN ;
128 
129    cenblok = (intvec **)realloc( cenblok, sizeof(intvec *)*(cenblok_num+1) ) ;
130 
131    COPY_intvec( cenblok[cenblok_num] , mb ) ;
132    cenblok_num++ ;
133    EXRETURN ;
134 }
135 
136 /*----------------------------------------------------------------------------*/
/* load block centros, if any; med[] is zeroed first, so entries not in any block end up 0 */
138 
load_block_centros(int nref,float * ppar,float * med)139 static void load_block_centros( int nref , float *ppar , float *med )
140 {
141    int bb , ii , kk, njj ;
142    float *bpar , mval ;
143 
144 ENTRY("load_block_centros") ;
145 
146    if( nref < 3 || ppar == NULL || med == NULL ) EXRETURN ;
147 
148    AAmemset( med , 0 , nref*sizeof(float) ) ;
149 
150    if( cenblok_num < 1 ) EXRETURN ;
151 
152    bpar = (float *)malloc(sizeof(float)*nref) ;
153 
154    /* loop over blocks [note subtract 1 from indexes in the intvecs */
155 
156    for( bb=0 ; bb < cenblok_num ; bb++ ){
157 
158      if( cenblok[bb]->nar < 3 ) continue ;          /* should be unpossible */
159 
160      for( njj=ii=0 ; ii < cenblok[bb]->nar ; ii++ ){ /* extract params */
161        kk = cenblok[bb]->ar[ii]-1 ;                  /* for this block */
162        if( kk >= 0 && kk < nref ) bpar[njj++] = ppar[kk] ;
163      }
164 
165      mval = centromean_float( njj , bpar ) ;  /* in cs_qmed.c */
166 
167 #if 0
168      if( lasso_verb && mval != 0.0f ){
169        char str[2048] ;
170        str[0] = '\0' ;
171        for( ii=0 ; ii < cenblok[bb]->nar ; ii++ ){
172         sprintf(str+strlen(str)," %d:%g",cenblok[bb]->ar[ii],bpar[ii]) ;
173        }
174        INFO_message("LASSO: LMB[%d] =%s => %g  njj=%d",bb,str,mval,njj) ;
175     }
176 #endif
177 
178      for( ii=0 ; ii < cenblok[bb]->nar ; ii++ ){     /* load med[] */
179        kk = cenblok[bb]->ar[ii]-1 ;                  /* for this block */
180        if( kk >= 0 && kk < nref ) med[kk] = mval ;
181      }
182 
183    }
184 
185    free(bpar) ;
186    EXRETURN ;
187 }
188 
189 /*----------------------------------------------------------------------------*/
190 
/* Rough noise level of far[0..npt-1]: the larger of two scaled MADs
   (via qmedmad_float), one of the nonzero 1st differences and one of
   the nonzero 2nd differences.  Each of the two estimates stays at
   0.333 when there are no nonzero differences of that order, and is
   just the absolute value when there is exactly one.  Returns 0.333
   outright if npt < 5 or far == NULL.                               */

static float estimate_sigma( int npt , float *far )
{
   float *work , dd ;
   float sig1 = 0.333f , sig2 = 0.333f ;
   int kk , nkeep ;

   if( far == NULL || npt < 5 ) return 0.333f ;

#pragma omp critical (MALLOC)
   { work = (float *)malloc(sizeof(float)*npt) ; }

   /*-- 1st differences --*/

   nkeep = 0 ;
   for( kk=0 ; kk < npt-1 ; kk++ ){
     dd = far[kk+1]-far[kk] ;
     if( dd != 0.0f ) work[nkeep++] = dd ;
   }
   if( nkeep > 1 ){
     qmedmad_float( nkeep , work , NULL , &sig1 ) ;
     sig1 *= 0.456f ;
   } else if( nkeep == 1 ){
     sig1 = fabsf(work[0]) ;
   }

   /*-- 2nd differences (workspace is re-used) --*/

   nkeep = 0 ;
   for( kk=0 ; kk < npt-2 ; kk++ ){
     dd = 0.5f*(far[kk+2]+far[kk])-far[kk+1] ;
     if( dd != 0.0f ) work[nkeep++] = dd ;
   }
   if( nkeep > 1 ){
     qmedmad_float( nkeep , work , NULL , &sig2 ) ;
     sig2 *= 0.567f ;
   } else if( nkeep == 1 ){
     sig2 = fabsf(work[0]) ;
   }

#pragma omp critical (MALLOC)
   { free(work) ; }

   return MAX(sig1,sig2) ;
}
227 
228 /*----------------------------------------------------------------------------*/
229 /* Construct a local copy of lam[], and edit it softly. */
230 
edit_lamvec(int npt,int nref,float * lam)231 static float * edit_lamvec( int npt , int nref , float *lam )
232 {
233    float *mylam ;
234    int nfree , jj ;
235 
236 ENTRY("edit_lamvec") ;
237 
238 #pragma omp critical (MALLOC)
239    { mylam = (float *)calloc(sizeof(float),nref) ; }
240 
241    nfree = nref ;
242    if( lam != NULL ){                       /* copy input lam */
243      for( nfree=jj=0 ; jj < nref ; jj++ ){
244        mylam[jj] = MAX(0.0f,lam[jj]) ; if( mylam[jj] == 0.0f ) nfree++ ;
245      }
246    }
247 
248    if( nfree >= MIN(nref,npt) ){ /* no good input lam, so make one up */
249      nfree = 0 ;
250      if( vlam != NULL ){         /* take from user-supplied vector */
251        int nvlam = vlam->nar ;
252        for( jj=0 ; jj < nref ; jj++ ){
253          if( jj < nvlam ){
254            mylam[jj] = vlam->ar[jj] ;
255                 if( mylam[jj] <  0.0f ) mylam[jj] = flam ;
256            else if( mylam[jj] == 0.0f ) nfree++ ;
257          } else {
258            mylam[jj] = flam ;
259          }
260        }
261        if( nfree >= npt ){               /* too many free values */
262          for( jj=0 ; jj < nref ; jj++ )
263            if( mylam[jj] == 0.0f ) mylam[jj] = flam ;
264        }
265      } else {                            /* fixed value of lam */
266        for( jj=0 ; jj < nref ; jj++ ) mylam[jj] = flam ;
267      }
268    }
269 
270    RETURN(mylam) ;
271 }
272 
273 /*----------------------------------------------------------------------------*/
274 /* Compute the un-penalized solution to only the 'free' parameters,
275    as marked in fr[].  The results are stored back into ppar[]; un-free
276    parameters in ppar[] are not altered.
277 *//*--------------------------------------------------------------------------*/
278 
/* Fit far[] with THD_fitter() using only the references flagged in fr[],
   and store the resulting coefficients into the matching slots of ppar[];
   unflagged slots are left alone.  The ccon[] values of the flagged
   references are passed along, unless all of them are zero (then no
   constraint vector is passed).  Does nothing if nfree <= 0, if
   nfree >= npt/2, or if fr or ppar is NULL.                            */

static void compute_free_param( int npt  , float *far   ,
                                int nref , float *ref[] ,
                                int meth , float *ccon  ,
                                int nfree, byte  *fr    , float *ppar )
{
   float **sel_ref ;
   float  *sel_con = NULL ;
   floatvec *fit ;
   int ncon = 0 , nsel = 0 , kk ;

ENTRY("compute_free_param") ;

   if( nfree <= 0 || nfree >= npt/2 || fr == NULL || ppar == NULL ) EXRETURN ;

#pragma omp critical (MALLOC)
   { sel_ref = (float **)calloc(sizeof(float *),nfree) ;
     if( ccon != NULL ) sel_con = (float *)calloc(sizeof(float),nfree) ;
   }

   /*-- gather the flagged regressors and their constraints --*/

   for( kk=0 ; kk < nref ; kk++ ){
     if( !fr[kk] ) continue ;
     if( ccon != NULL && ccon[kk] != 0 ){
       sel_con[nsel] = ccon[kk] ; ncon++ ;
     }
     sel_ref[nsel] = ref[kk] ;
     nsel++ ;
   }

   /*-- no nonzero constraints => drop the constraint vector --*/

#pragma omp critical (MALLOC)
   { if( ncon == 0 ){ free(sel_con) ; sel_con = NULL ; } }

   /*-- do the fit, then scatter the results back --*/

   fit = THD_fitter( npt , far , nfree , sel_ref , meth , sel_con ) ;

   if( fit != NULL ){
     nsel = 0 ;
     for( kk=0 ; kk < nref ; kk++ ){
       if( fr[kk] ){ ppar[kk] = fit->ar[nsel] ; nsel++ ; }
     }
   }

   /*-- release the workspace --*/

#pragma omp critical (MALLOC)
   { free(sel_ref) ;
     if( sel_con != NULL ) free(sel_con) ;
     if( fit     != NULL ) KILL_floatvec(fit) ;
   }

   EXRETURN ;
}
327 
328 /*----------------------------------------------------------------------------*/
329 /* Check inputs for stupidities */
330 
/* Sanity check of the fitting inputs.  Returns
     0 = usable
     1 = npt <= 1, nref <= 0, or far/ref is NULL
     2 = some ref[j] is NULL (for j in 0..nref-1)  */

static int check_inputs( int npt  , float *far ,
                         int nref , float *ref[] )
{
   int kk = 0 ;

   if( far == NULL || ref == NULL || npt <= 1 || nref <= 0 ) return 1 ;

   while( kk < nref ){
     if( ref[kk] == NULL ) return 2 ;
     kk++ ;
   }

   return 0 ;
}
339 
340 /*----------------------------------------------------------------------------*/
341 
THD_lasso(int meth,int npt,float * far,int nref,float * ref[],float * lam,float * ccon)342 floatvec * THD_lasso( int meth   ,
343                       int npt    , float *far   ,
344                       int nref   , float *ref[] ,
345                       float *lam , float *ccon   )
346 {
347    switch( meth ){
348 
349      default:
350      case  2:
351      case -2: return THD_lasso_L2fit    ( npt,far , nref,ref , lam,ccon ) ;
352 
353      case  1:
354      case -1: return THD_sqrtlasso_L2fit( npt,far , nref,ref , lam,ccon ) ;
355 
356    }
357    return NULL ; /* unreachable */
358 }
359 
360 /*----------------------------------------------------------------------------*/
361 /* LASSO (L2 fit + L1 penalty) fit a vector to a set of reference vectors.
362    Input parameters are
363     * npt     = Length of input vectors
364     * nref    = Number of references
365     * far     = Vector to be fitted
366     * ref[k]  = k-th reference vector, for k=0..nref-1
367     * lam[k]  = L1 penalty factor for the k-th reference (non-negative)
368                 If lam == NULL, or all values are zero, then the value
369                 set by THD_lasso_fixlam() will be used instead.
370     * ccon[k] = If ccon != NULL, ccon[k] is a sign constraint on the k-th
371                 output coefficient: ccon[k] = 0 == no constraint
372                                             > 0 == coef #k must be >= 0
373                                             < 0 == coef #k must be <= 0
374    Unlike standard linear fitting, nref can be more than npt, since the
375    L1 penalty can force some coefficients to be exactly zero.  However,
376    at most npt-1 values of lam[] can be zero (or the problem is unsolvable).
377 
378    The return vector contains the nref coefficients.  If NULL is returned,
379    then something bad bad bad transpired and you should hang your head.
380 
381    TT Wu and K Lange.
382    Coordinate descent algorithms for LASSO penalized regression.
383    Annals of Applied Statistics, 2: 224-244 (2008).
384    http://arxiv.org/abs/0803.3876
385 *//*--------------------------------------------------------------------------*/
386 
THD_lasso_L2fit(int npt,float * far,int nref,float * ref[],float * lam,float * ccon)387 floatvec * THD_lasso_L2fit( int npt    , float *far   ,
388                             int nref   , float *ref[] ,
389                             float *lam , float *ccon   )
390 {
391    int ii,jj, nfree,nite,nimax,ndel , do_slam=0 ;
392    float *mylam, *ppar, *resd, *rsq, *rj, pj,dg,dsum,dsumx,ll ;
393    floatvec *qfit ; byte *fr ;
394    float *med , mval , pv , mv ;  /* for centro blocks [06 Aug 2021] */
395 
396 ENTRY("THD_lasso_L2fit") ;
397 
398    jj = check_inputs( npt , far , nref , ref ) ;
399    if( jj ){
400      static int ncall=0 ;
401      if( ncall < 2 ){ ERROR_message("LASSO: bad data and/or model"); ncall++; }
402      RETURN(NULL) ;
403    }
404 
405    /*--- construct a local copy of lam[], and edit it softly ---*/
406 
407    mylam = edit_lamvec( npt , nref , lam ) ;
408 
409    /*--- space for parameter iterates, etc (initialized to zero) ---*/
410 
411 #pragma omp critical (MALLOC)
412    { MAKE_floatvec(qfit,nref) ; ppar = qfit->ar ; /* parameters = output */
413      resd = (float *)calloc(sizeof(float),npt ) ; /* residuals */
414      rsq  = (float *)calloc(sizeof(float),nref) ; /* sums of squares */
415      fr   = (byte  *)calloc(sizeof(byte) ,nref) ; /* free list */
416      med  = (float *)calloc(sizeof(float),nref) ; /* block centros */
417    }
418 
419    /*--- Save 1/(sum of squares) of each ref column ---*/
420 
421    dsum = (do_sigest) ? estimate_sigma(npt,far) : 1.0f ;
422 
423    nfree = 0 ;                    /* number of unconstrained parameters */
424    for( jj=0 ; jj < nref ; jj++ ){
425      rj = ref[jj] ;
426      for( pj=ii=0 ; ii < npt ; ii++ ) pj += rj[ii]*rj[ii] ;
427      if( pj > 0.0f ){
428        rsq[jj] = 1.0f / pj ;
429        if( mylam[jj] == 0.0f ){   /* unconstrained parameter */
430          fr[jj] = 1 ;  nfree++ ;
431        } else {
432          mylam[jj] *= dsum * sqrtf(pj) ; /* scale for size of regressors */
433        }
434      }
435    }
436 
437    /*--- if any parameters are free (no L1 penalty),
438          initialize them by un-penalized least squares
439          (implicitly assuming all other parameters are zero) ---*/
440 
441    if( AO_VALUE(vpar) == NULL || AO_VALUE(vpar)->nar < nref ){
442      /* compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ; */
443    } else {
444      for( ii=0 ; ii < nref ; ii++ ) ppar[ii] = AO_VALUE(vpar)->ar[ii] ;
445    }
446 
447    /*--- initialize residuals ---*/
448 
449    for( ii=0 ; ii < npt ; ii++ ) resd[ii] = far[ii] ;          /* data */
450 
451    for( jj=0 ; jj < nref ; jj++ ){  /* subtract off fit of each column */
452      pj = ppar[jj] ; rj = ref[jj] ; /* with a nonzero parameter value */
453      if( pj != 0.0f ){
454        for( ii=0 ; ii < npt ; ii++ ) resd[ii] -= rj[ii]*pj ;
455      }
456    }
457 
458    /*--- if have a lot of references (relative to number of data points),
459          then increase lam[] for the first iterations to speed convergence ---*/
460 
461    do_slam = 3 * (nref > npt/4) ;                        /* first 3 */
462    if( do_slam ){                                        /*       | */
463      for( jj=0 ; jj < nref ; jj++ ) mylam[jj] *= 8.0f ;  /* 8 = 2^3 */
464    }
465 
466    /*---- outer iteration loop (until we are happy or worn out) ----*/
467 
468 #undef  CON     /* CON(j) is true if the constraint on ppar[j] is violated */
469 #define CON(j)  (ccon != NULL && ppar[j]*ccon[j] < 0.0f)
470 
471 #undef  CONP    /* CONP(j) is true if ppar[j] is supposed to be >= 0 */
472 #define CONP(j) (ccon != NULL && ccon[j] > 0.0f)
473 
474 #undef  CONN    /* CONN(j) is true if ppar[j] is supposed to be <= 0 */
475 #define CONN(j) (ccon != NULL && ccon[j] < 0.0f)
476 
477    { AO_DEFINE_SCALAR(int,ncall) ;
478      lasso_verb = ( AO_VALUE(ncall) < 2 || AO_VALUE(ncall)%10000 == 1 ) ;
479      AO_VALUE(ncall)++ ;
480    }
481 
482    ii = MAX(nref,npt) ; jj = MIN(nref,npt) ; nimax = 17 + 5*ii + 31*jj ;
483    dsumx = dsum = 1.0f ;
484 #if 0
485    if( lasso_verb && cenblok_num > 0 ) INFO_message("LASSO => start iterations") ;
486 #endif
487    for( nite=0 ; nite < nimax && dsum+dsumx > deps ; nite++ ){
488 
489      /*--- load block centros [06 Aug 2021] ---*/
490 
491      if( nite > 3 ) load_block_centros( nref , ppar , med ) ;
492 
493      /*-- cyclic inner loop over parameters --*/
494 
495      dsumx = dsum ;
496 #if 1
497      for( dsum=ndel=jj=0 ; jj < nref ; jj++ ){  /* dsum = sum of param deltas */
498 #else
499      for( dsum=ndel=0,jj=nref-1 ; jj >= 0 ; jj-- ){  /* dsum = sum of param deltas */
500 #endif
501 
502        if( rsq[jj] == 0.0f ) continue ; /* all zero column!? */
503        rj = ref[jj] ;                   /* j-th reference column */
504        pj = ppar[jj] ;                  /* current value of j-th parameter */
505        ll = mylam[jj] ;                 /* lambda for this param */
506 #if 1
507        mv = med[jj] ;                   /* shrinkage target (e.g., 0) */
508 #else
509        mv = 0.0f ;
510 #endif
511        pv = pj - mv ;                   /* param diff from shrinkage target */
512 
513        /* compute dg = -gradient of un-penalized function wrt ppar[jj] */
514        /*            = direction we want to step in                    */
515 
516        for( dg=ii=0 ; ii < npt ; ii++ ) dg += resd[ii] * rj[ii] ;
517 
518        /*- modify parameter down the gradient -*/
519 
520        if( ll == 0.0f ){          /* un-penalized parameter */
521 
522          ppar[jj] += dg*rsq[jj] ; if( CON(jj) ) ppar[jj] = 0.0f ;
523 
524        } else {                   /* penalized parameter */
525 
526          /* Extra -gradient is -lambda for pj > 0, and is +lambda for pj < 0. */
527          /* Merge this with dg, change ppar[jj], then see if we stepped thru */
528          /* zero (or hit a constraint) -- if so, stop ppar[jj] at zero.     */
529 
530          if( pv > 0.0f || (pv == 0.0f && dg > ll) ){    /* on the + side */
531            dg -= ll ; ppar[jj] += dg*rsq[jj] ;          /* shrink - way */
532            if( ppar[jj] < mv ) ppar[jj] = mv ;          /* went too far down? */
533            if( CON(jj)       ) ppar[jj] = 0.0f ;        /* violate constraint? */
534          } else if( pv < 0.0f || (pv == 0.0f && dg < -ll) ){ /* on the - side */
535            dg += ll ; ppar[jj] += dg*rsq[jj] ;               /* shrink + way */
536            if( ppar[jj] > mv ) ppar[jj] = mv ;               /* too far up? */
537            if( CON(jj)       ) ppar[jj] = 0.0f ;             /* constraint? */
538          }
539 
540        }
541 
542        dg = ppar[jj] - pj ;   /* change in parameter */
543        if( dg != 0.0f ){      /* update convergence test and residuals */
544          pj    = fabsf(ppar[jj]) + fabsf(pj) ;
545          dsum += fabsf(dg) / MAX(pj,0.001f) ; ndel++ ;
546          for( ii=0 ; ii < npt ; ii++ ) resd[ii] -= rj[ii] * dg ;
547        }
548 
549      } /*-- end of inner loop over parameters --*/
550 
551      /**** test for convergence somehow ***/
552 
553      if( ndel > 0 ) dsum *= (2.0f/ndel) ;
554 
555      if( do_slam ){     /* shrink lam[] back, if it was augmented */
556        do_slam-- ; dsum = 1.0f ;
557        for( jj=0 ; jj < nref ; jj++ ) mylam[jj] *= 0.5f ;
558      }
559 
560    } /*---- end of outer iteration loop ----*/
561 
562 #if 1
563    { if( lasso_verb ){
564        for( nfree=jj=0 ; jj < nref ; jj++ ) nfree += (ppar[jj] != 0.0f) ;
565        INFO_message("\nLASSO: nite=%d dsum=%g dsumx=%g nfree=%d/%d",nite,dsum,dsumx,nfree,nref) ;
566      }
567    }
568 #endif
569 
570    /*--- if 'post' computation is ordered, re-do the
571          regression without constraints, but using only
572          the references with non-zero weights from above ---*/
573 
574    if( do_post ){
575      nfree = 0 ;
576      for( jj=0 ; jj < nref ; jj++ ){  /* count and mark params to use */
577        fr[jj] = (ppar[jj] != 0.0f) ; nfree += fr[jj] ;
578      }
579      compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ;
580    }
581 
582    /*--- Loading up the truck and heading to Beverlee ---*/
583 
584 #pragma omp critical (MALLOC)
585    { free(fr) ; free(rsq) ; free(resd) ; free(mylam) ; free(med) ; }
586 
587    RETURN(qfit) ;
588 }
589 
590 /*----------------------------------------------------------------------------*/
591 /**------ minimizers of f(x) = sqrt(a*x*x+b*x+c) + d*x
592           valid for d*d < a and for 4*a*c-b*b > 0  [positive quadratic] -----**/
593 
/* Shared pieces of the closed-form minimizers:
     SQL_GAP  = a - d*d                      (must be positive)
     SQL_ROOT = sqrt( d*d * (a-d*d) * (4*a*c-b*b) )              */

#undef  SQL_GAP
#define SQL_GAP(a,d)       ( (a)-(d)*(d) )

#undef  SQL_ROOT
#define SQL_ROOT(a,b,c,d)  \
   sqrtf( (d)*(d) * SQL_GAP(a,d) * (4.0f*(a)*(c)-(b)*(b)) )

#undef  XPLU             /* minimizer when d > 0 (and d*d < a) */
#define XPLU(a,b,c,d)                                     \
   ( -( (b) * SQL_GAP(a,d) + SQL_ROOT(a,b,c,d) )          \
      / ( 2.0f * (a) * SQL_GAP(a,d) ) )

#undef  XMIN             /* minimizer when d < 0 (and d*d < a) */
#define XMIN(a,b,c,d)                                     \
   ( -( (b) * SQL_GAP(a,d) - SQL_ROOT(a,b,c,d) )          \
      / ( 2.0f * (a) * SQL_GAP(a,d) ) )
607 
608 /*----------------------------------------------------------------------------*/
609 /* Square Root L2 LASSO, similar to L2 LASSO function directly above; see
610      A Belloni, V Chernozhukov, and L Wang.
611      Square-root LASSO: Pivotal recovery of sparse signals via conic programming.
612      http://arxiv.org/abs/1009.5689
613    This function uses a coordinate descent method similar to the pure
614    LASSO function earlier.  The key is that for 1 coordinate at a time, the
615    problem is reduced to minimizing a function of the form
616      f(x) = sqrt(a*x*x+b*x+c) + d*abs(x)
617    and the minimizer can be written in closed form as the solution to a
618    quadratic equation -- see XPLU and XMIN macros above.
619 *//*--------------------------------------------------------------------------*/
620 
621 floatvec * THD_sqrtlasso_L2fit( int npt    , float *far   ,
622                                 int nref   , float *ref[] ,
623                                 float *lam , float *ccon   )
624 {
625    int ii,jj, nfree,nite,nimax,ndel ;
626    float *mylam, *ppar, *resd, *rsq, *rj, pj,dg,dsum,dsumx ;
627    float rqsum,aa,bb,cc,ll,all , npinv , *ain,*qal ;
628    floatvec *qfit ; byte *fr ;
629 
630 ENTRY("THD_sqrtlasso_L2fit") ;
631 
632    /*--- check inputs for stupidities ---*/
633 
634    jj = check_inputs( npt , far , nref , ref ) ;
635    if( jj ){
636      static int ncall=0 ;
637      if( ncall < 2 ){ ERROR_message("SQRT LASSO: bad data and/or model"); ncall++; }
638      RETURN(NULL) ;
639    }
640 
641    /*--- construct a local copy of lam[], and edit it softly ---*/
642 
643    mylam = edit_lamvec( npt , nref , lam ) ;
644 
645    /*--- space for parameter iterates, etc (initialized to zero) ---*/
646 
647 #pragma omp critical (MALLOC)
648    { MAKE_floatvec(qfit,nref) ; ppar = qfit->ar ;   /* parameters = output */
649      resd = (float *)calloc(sizeof(float),npt ) ;   /* residuals */
650      rsq  = (float *)calloc(sizeof(float),nref) ;   /* sums of squares */
651      fr   = (byte  *)calloc(sizeof(byte) ,nref) ;   /* free list */
652      ain  = (float *)calloc(sizeof(float),nref) ;   /* -0.5 / rsq */
653      qal  = (float *)calloc(sizeof(float),nref) ;   /* step adjustment */
654    }
655 
656    /*--- Save sum of squares of each ref column ---*/
657 
658    npinv = 1.0f / (float)npt ;
659    nfree = 0 ;                          /* number of unconstrained parameters */
660    for( jj=0 ; jj < nref ; jj++ ){
661      rj = ref[jj] ;
662      for( pj=ii=0 ; ii < npt ; ii++ ) pj += rj[ii]*rj[ii] ; /* sum of squares */
663      rsq[jj] = pj * npinv ;                            /* average per data pt */
664      if( pj > 0.0f ){
665        if( mylam[jj] == 0.0f ){ fr[jj] = 1 ; nfree++ ; }     /* unconstrained */
666        ain[jj] = -0.5f / rsq[jj] ;
667      }
668    }
669 
670    /* scale and edit mylam to make sure it isn't too big for sqrt(quadratic) */
671 
672 #undef  AFAC
673 #define AFAC 0.222f
674 
675    cc = sqrtf(npinv)*AFAC ;
676    for( jj=0 ; jj < nref ; jj++ ){
677      ll = mylam[jj] ;
678      if( ll > 0.0f ){
679        aa = sqrtf(rsq[jj]); ll *= aa*cc; if( ll > AFAC*aa ) ll = AFAC*aa;
680        mylam[jj] = ll ;
681        qal[jj]   = ll*ll * 4.0f*rsq[jj]/(rsq[jj]-ll*ll) ;
682      }
683    }
684 
685    /*--- if any parameters are free (no L1 penalty),
686          initialize them by un-penalized least squares
687          (implicitly assuming all other parameters are zero) ---*/
688 
689    if( AO_VALUE(vpar) == NULL || AO_VALUE(vpar)->nar < nref ){
690      /* compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ; */
691    } else {
692      for( ii=0 ; ii < nref ; ii++ ) ppar[ii] = AO_VALUE(vpar)->ar[ii] ;
693    }
694 
695    /*--- initialize residuals ---*/
696 
697    for( ii=0 ; ii < npt ; ii++ ) resd[ii] = far[ii] ;          /* data */
698 
699    for( jj=0 ; jj < nref ; jj++ ){  /* subtract off fit of each column */
700      pj = ppar[jj] ; rj = ref[jj] ;  /* with a nonzero parameter value */
701      if( pj != 0.0f ){
702        for( ii=0 ; ii < npt ; ii++ ) resd[ii] -= rj[ii]*pj ;
703      }
704    }
705    for( rqsum=ii=0 ; ii < npt ; ii++ ) rqsum += resd[ii]*resd[ii] ;
706    rqsum *= npinv ;
707 
708    /*---- outer iteration loop (until we are happy or worn out) ----*/
709 
710    ii = MAX(nref,npt) ; jj = MIN(nref,npt) ; nimax = 17 + 5*ii + 31*jj ;
711    dsumx = dsum = 1.0f ;
712    for( nite=0 ; nite < nimax && dsum+dsumx > deps ; nite++ ){
713 
714      /*-- cyclic inner loop over parameters --*/
715 
716      dsumx = dsum ;                             /* save last value of dsum */
717 
718      for( dsum=ndel=jj=0 ; jj < nref ; jj++ ){  /* dsum = sum of param deltas */
719 
720        if( rsq[jj] == 0.0f ) continue ; /* all zero column!? */
721        rj = ref[jj] ;                   /* j-th reference column */
722        pj = ppar[jj] ;                  /* current value of j-th parameter */
723 
724        for( dg=ii=0 ; ii < npt ; ii++ ) dg += resd[ii] * rj[ii] ;
725        dg *= npinv ;
726 
727        /* want to minimize (wrt x) function sqrt(aa*x*x+bb*x+cc) + ll*abs(x) */
728 
729        aa = rsq[jj] ;
730        bb = -2.0f * (dg+aa*pj) ;
731        cc = rqsum + (2.0f*dg + aa*pj)*pj ;
732        ll = mylam[jj] ;
733 
734        /*- modify parameter -*/
735 
736        if( ll == 0.0f ){   /* un-penalized parameter */
737 
738          ppar[jj] = bb * ain[jj] ; if( CON(jj) ) ppar[jj] = 0.0f ;
739 
740        } else {
741 #if 0
742          float qq = ll * sqrtf(4.0f*aa*cc/(aa-ll*ll)) ;
743          if( pj > 0.0f || (pj == 0.0f && bb+qq < 0.0f) ){
744            ppar[jj] = (bb+qq) * ain[jj] ;      /* solution on positive side */
745            if( ppar[jj] < 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
746          } else if( pj < 0.0f || (pj == 0.0f && bb-qq > 0.0f) ){
747            ppar[jj] = (bb-qq) * ain[jj] ;      /* solution on negative side */
748            if( ppar[jj] > 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
749          }
750 #else
751          float qq = qal[jj] * cc ;              /* positive by construction */
752          if( pj > 0.0f ){
753            ppar[jj] = (bb+sqrtf(qq)) * ain[jj] ;           /* positive side */
754            if( ppar[jj] < 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
755          } else if( pj < 0.0f ){
756            ppar[jj] = (bb-sqrtf(qq)) * ain[jj] ;           /* negative side */
757            if( ppar[jj] > 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
758          } else {                                        /* initial pj == 0 */
759            if( bb*bb > qq ){         /* gradient step overpowers L1 penalty */
760              if( bb < 0.0f && !CONN(jj) ){                /* [note ain < 0] */
761                ppar[jj] = (bb+sqrtf(qq)) * ain[jj] ;      /* step to + side */
762              } else if( !CONP(jj) ){
763                ppar[jj] = (bb-sqrtf(qq)) * ain[jj] ;      /* step to - side */
764              }
765            }
766          }
767 #endif
768        }
769 
770        dg = ppar[jj] - pj ;   /* change in parameter */
771        if( dg != 0.0f ){      /* update convergence test and residuals */
772          pj    = fabsf(ppar[jj]) + fabsf(pj) ;
773          dsum += fabsf(dg) / MAX(pj,0.001f) ; ndel++ ;
774          for( rqsum=ii=0 ; ii < npt ; ii++ ){
775            resd[ii] -= rj[ii] * dg ; rqsum += resd[ii]*resd[ii] ;
776          }
777          rqsum *= npinv ;
778        }
779 
780      } /*-- end of inner loop over parameters --*/
781 
782      /**** test for convergence somehow ***/
783 
784      if( ndel > 0 ) dsum *= (2.0f/ndel) ;
785 
786    } /*---- end of outer iteration loop ----*/
787 
788 #if 1
789    { AO_DEFINE_SCALAR(int,ncall) ;
790      if( AO_VALUE(ncall) < 2 || AO_VALUE(ncall)%10000 == 1 ){
791        for( nfree=jj=0 ; jj < nref ; jj++ ) nfree += (ppar[jj] != 0.0f) ;
792        INFO_message("SQRTLASSO %d: nite=%d dsum=%g dsumx=%g nfree=%d/%d",AO_VALUE(ncall),nite,dsum,dsumx,nfree,nref) ;
793      }
794      AO_VALUE(ncall)++ ;
795    }
796 #endif
797 
798    /*--- if 'post' computation is ordered, re-do the
799          regression without constraints, but using only
800          the references with non-zero weights from above ---*/
801 
802    if( do_post ){
803      nfree = 0 ;
804      for( jj=0 ; jj < nref ; jj++ ){  /* count and mark params to use */
805        fr[jj] = (ppar[jj] != 0.0f) ; nfree += fr[jj] ;
806      }
807      compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ;
808    }
809 
810    /*--- Loading up the truck and heading to Beverlee ---*/
811 
812 #pragma omp critical (MALLOC)
813    { free(qal) ; free(ain) ; free(fr) ; free(rsq) ; free(resd) ; free(mylam) ; }
814 
815    RETURN(qfit) ;
816 }
817