1 #include "mrilib.h"
2
3 /*-------------------------------------------------------------------*/
4 /*** LASSO functions for AFNI programs. ***/
5 /*** We start with some utilities for setting LASSO parameters. ***/
6 /*** Then some internal (static) functions for common necessities. ***/
7 /*-------------------------------------------------------------------*/
8
9 static int lasso_verb = 0 ;
10
11 /*----------------------------------------------------------------------------*/
12
13 /** set the fixed value of lambda (flam);
14 note that flam will always be positive (never 0) **/
15
16 /* this is set once-and-for-all so doesn't need to be thread-ized */
17
18 static float flam = 0.666f ;
19
THD_lasso_fixlam(float x)20 void THD_lasso_fixlam( float x ){ if( x > 0.0f ) flam = x ; }
21
22 /*............................................................................*/
23
24 /** set the convergence parameter (deps) [function not used at this time] **/
25
26 static float deps = 0.0000321111f ;
27
THD_lasso_setdeps(float x)28 void THD_lasso_setdeps( float x ){
29 deps = ( x >= 0.0000003f && x <= 0.1f ) ? x : 0.0000654321f ;
30 }
31
32 /*............................................................................*/
33
34 /** set this to 1 to do 'post-LASSO' re-regression [not used at this time] **/
35
36 static int do_post = 0 ;
37
THD_lasso_dopost(int x)38 void THD_lasso_dopost( int x ){ do_post = x ; }
39
40 /*............................................................................*/
41
42 /** set this to 1 to scale LASSO lambda by estimated sigma **/
43
44 /* this is set once-and-for-all so doesn't need to be thread-ized */
45
46 static int do_sigest = 0 ;
47
THD_lasso_dosigest(int x)48 void THD_lasso_dosigest( int x ){ do_sigest = x ; }
49
50 /*............................................................................*/
51
52 /** set the entire lambda vector **/
53
54 static floatvec *vlam = NULL ;
55
56 /* this is set once-and-for-all so doesn't need to be thread-ized */
57
/* Store a copy of the user's lambda vector; lam==NULL or nref<=0 clears it. */

void THD_lasso_setlamvec( int nref , float *lam )
{
   register int ii ;
   ENTRY("THD_lasso_setlamvec") ;
#pragma omp critical (MALLOC)
   { KILL_floatvec(vlam) ; }               /* discard any previous vector */
   if( nref > 0 && lam != NULL ){
#pragma omp critical (MALLOC)
     { MAKE_floatvec(vlam,nref) ; }        /* allocate and copy the input */
     for( ii=0 ; ii < nref ; ii++ ) vlam->ar[ii] = lam[ii] ;
   }
   EXRETURN ;
}
71
72 /*............................................................................*/
73
74 /** set initial parameters estimates **/
75
76 /* not used at this time, but is thread-ized for safety */
77
78 AO_DEFINE_SCALAR(floatvec*,vpar) ;
79 #if 0
80 static floatvec *vpar = NULL ;
81 #endif
82
/* Store a copy of initial parameter estimates; par==NULL or nref<=0 clears. */

void THD_lasso_setparvec( int nref , float *par )
{
   register int ii ;
   ENTRY("THD_lasso_setparvec") ;
#pragma omp critical (MALLOC)
   { KILL_floatvec(AO_VALUE(vpar)) ; }     /* discard any previous estimates */
   if( nref > 0 && par != NULL ){
#pragma omp critical (MALLOC)
     { MAKE_floatvec(AO_VALUE(vpar),nref) ; }
     for( ii=0 ; ii < nref ; ii++ ) AO_VALUE(vpar)->ar[ii] = par[ii] ;
   }
   EXRETURN ;
}
96
97 /*----------------------------------------------------------------------------*/
98 /* Centro blocks = indexes over which the shrinkage is toward the centromean
99 parameter (over the block) rather than toward 0.
100 Each centro block must have at least 3 entries.
101 NOTES: If the caller is an idiot, stupid things will happen; for example:
102 * If any of the entries of mb->ar[] are out of
103 the index range of the parameters (0..nref-1)
104 * If multiple centro blocks are used and share some indexes
105 * If an un-penalized index (mylam[i]==0) is provided
106 My suggestion is to avoid being an idiot. [Aug 2021 - RWCox]
107 *//*--------------------------------------------------------------------------*/
108
109 /* set once-and-for-all so doesn't need to be thread-ized */
110
111 static int cenblok_num = 0 ;
112 static intvec **cenblok = NULL ;
113
/* Append a copy of mb to the list of centro blocks, or clear the whole
   list if mb==NULL.  BUG FIX: the clear branch formerly free()d cenblok
   without resetting cenblok/cenblok_num, leaving a dangling pointer and
   a stale count -- the next add would realloc freed memory.             */

void THD_lasso_add_centro_block( intvec *mb )
{
   ENTRY("THD_lasso_add_centro_block") ;

   if( mb == NULL ){   /* signal to clear all centro blocks */
     int ii ;
     if( cenblok != NULL ){
       for( ii=0 ; ii < cenblok_num ; ii++ ){ KILL_intvec( cenblok[ii] ) ; }
       free(cenblok) ;
     }
     cenblok = NULL ; cenblok_num = 0 ;   /* reset so later adds start clean */
     EXRETURN ;
   }

   if( mb->nar < 3 || mb->ar == NULL ) EXRETURN ;  /* need at least 3 entries */

   /* lengthen the list of blocks and append a copy of mb */

   cenblok = (intvec **)realloc( cenblok, sizeof(intvec *)*(cenblok_num+1) ) ;

   COPY_intvec( cenblok[cenblok_num] , mb ) ;
   cenblok_num++ ;
   EXRETURN ;
}
135
/*----------------------------------------------------------------------------*/
/* load block centros, if any; med[] entries not in a block are set to 0 */
138
/* Fill med[] with per-parameter shrinkage targets: med[] is zeroed, then
   each index in a centro block gets the centromean of the current ppar[]
   values in that block (indexes in the intvecs are 1-based).            */

static void load_block_centros( int nref , float *ppar , float *med )
{
   int bb , ii , kk, njj ;
   float *bpar , mval ;

   ENTRY("load_block_centros") ;

   if( nref < 3 || ppar == NULL || med == NULL ) EXRETURN ;

   AAmemset( med , 0 , nref*sizeof(float) ) ;   /* default target = 0 */

   if( cenblok_num < 1 ) EXRETURN ;             /* no blocks ==> all done */

   bpar = (float *)malloc(sizeof(float)*nref) ; /* workspace for block params */

   /* loop over blocks [note: subtract 1 from indexes in the intvecs] */

   for( bb=0 ; bb < cenblok_num ; bb++ ){

     if( cenblok[bb]->nar < 3 ) continue ; /* should be unpossible */

     for( njj=ii=0 ; ii < cenblok[bb]->nar ; ii++ ){ /* extract params */
       kk = cenblok[bb]->ar[ii]-1 ;                  /* for this block */
       if( kk >= 0 && kk < nref ) bpar[njj++] = ppar[kk] ;
     }

     mval = centromean_float( njj , bpar ) ; /* in cs_qmed.c */

#if 0
     if( lasso_verb && mval != 0.0f ){
       char str[2048] ;
       str[0] = '\0' ;
       for( ii=0 ; ii < cenblok[bb]->nar ; ii++ ){
         sprintf(str+strlen(str)," %d:%g",cenblok[bb]->ar[ii],bpar[ii]) ;
       }
       INFO_message("LASSO: LMB[%d] =%s => %g njj=%d",bb,str,mval,njj) ;
     }
#endif

     for( ii=0 ; ii < cenblok[bb]->nar ; ii++ ){ /* load med[] */
       kk = cenblok[bb]->ar[ii]-1 ;              /* for this block */
       if( kk >= 0 && kk < nref ) med[kk] = mval ;
     }

   }

   free(bpar) ;
   EXRETURN ;
}
188
189 /*----------------------------------------------------------------------------*/
190
/* Scaled MAD of dif[0..nnz-1]: returns the fallback 0.333f for nnz<=0,
   |dif[0]| for a single value, otherwise MAD (via qmedmad_float) * scl.  */

static float scaled_mad( int nnz , float *dif , float scl )
{
   float mad = 0.333f ;
   if( nnz == 1 ){
     mad = fabsf(dif[0]) ;
   } else if( nnz > 1 ){
     qmedmad_float( nnz , dif , NULL , &mad ) ; mad *= scl ;
   }
   return mad ;
}

/* Robust sigma estimate for far[]: the larger of the (scaled) MADs of the
   nonzero 1st and 2nd differences.  [Refactor: the two duplicated MAD
   branches are now the scaled_mad() helper above; behavior unchanged.]   */

static float estimate_sigma( int npt , float *far )
{
   float *dif , val , mad1,mad2 ; int ii,nnz ;

   if( npt < 5 || far == NULL ) return 0.333f ; /* half a milli-beast */

#pragma omp critical (MALLOC)
   { dif = (float *)malloc(sizeof(float)*npt) ; }

   /* MAD of 1st differences */

   for( nnz=ii=0 ; ii < npt-1 ; ii++ ){
     val = far[ii+1]-far[ii] ; if( val != 0.0f ) dif[nnz++] = val ;
   }
   mad1 = scaled_mad( nnz , dif , 0.456f ) ;

   /* MAD of 2nd differences */

   for( nnz=ii=0 ; ii < npt-2 ; ii++ ){
     val = 0.5f*(far[ii+2]+far[ii])-far[ii+1]; if( val != 0.0f ) dif[nnz++] = val;
   }
   mad2 = scaled_mad( nnz , dif , 0.567f ) ;

#pragma omp critical (MALLOC)
   { free(dif) ; }

   return MAX(mad1,mad2) ;
}
227
228 /*----------------------------------------------------------------------------*/
229 /* Construct a local copy of lam[], and edit it softly. */
230
/* Return a malloc-ed working copy of lam[] (length nref), edited so the
   problem stays solvable: if too many entries are free (zero), substitute
   the user vector vlam, or failing that the fixed value flam.            */

static float * edit_lamvec( int npt , int nref , float *lam )
{
   float *mylam ;
   int nfree , jj ;

   ENTRY("edit_lamvec") ;

#pragma omp critical (MALLOC)
   { mylam = (float *)calloc(sizeof(float),nref) ; }  /* zero-filled copy */

   nfree = nref ;       /* nfree = count of un-penalized (lam==0) params */
   if( lam != NULL ){   /* copy input lam */
     for( nfree=jj=0 ; jj < nref ; jj++ ){
       mylam[jj] = MAX(0.0f,lam[jj]) ; if( mylam[jj] == 0.0f ) nfree++ ;
     }
   }

   if( nfree >= MIN(nref,npt) ){  /* no good input lam, so make one up */
     nfree = 0 ;
     if( vlam != NULL ){          /* take from user-supplied vector */
       int nvlam = vlam->nar ;
       for( jj=0 ; jj < nref ; jj++ ){
         if( jj < nvlam ){
           mylam[jj] = vlam->ar[jj] ;       /* negative ==> use default flam */
           if( mylam[jj] < 0.0f ) mylam[jj] = flam ;
           else if( mylam[jj] == 0.0f ) nfree++ ;
         } else {
           mylam[jj] = flam ;               /* vlam too short: pad with flam */
         }
       }
       if( nfree >= npt ){        /* too many free values */
         for( jj=0 ; jj < nref ; jj++ )
           if( mylam[jj] == 0.0f ) mylam[jj] = flam ;
       }
     } else {                     /* fixed value of lam */
       for( jj=0 ; jj < nref ; jj++ ) mylam[jj] = flam ;
     }
   }

   RETURN(mylam) ;
}
272
273 /*----------------------------------------------------------------------------*/
274 /* Compute the un-penalized solution to only the 'free' parameters,
275 as marked in fr[]. The results are stored back into ppar[]; un-free
276 parameters in ppar[] are not altered.
277 *//*--------------------------------------------------------------------------*/
278
static void compute_free_param( int npt , float *far ,
                                int nref , float *ref[] ,
                                int meth , float *ccon ,
                                int nfree, byte *fr , float *ppar )
{
   float **qref,*qcon=NULL ; floatvec *qfit ; int nc,ii,jj ;

   ENTRY("compute_free_param") ;

   /* bail out if nothing is free, or too many free params for a stable fit */

   if( nfree <= 0 || nfree >= npt/2 || fr == NULL || ppar == NULL ) EXRETURN ;

#pragma omp critical (MALLOC)
   { qref = (float **)calloc(sizeof(float *),nfree) ;
     if( ccon != NULL ) qcon = (float * )calloc(sizeof(float) ,nfree) ; }

   /* select the marked regressors and constraints */

   for( nc=ii=jj=0 ; jj < nref ; jj++ ){
     if( fr[jj] ){    /* use this parameter */
       if( ccon != NULL && ccon[jj] != 0 ){ qcon[ii] = ccon[jj]; nc++; }
       qref[ii++] = ref[jj] ;
     }
   }

#pragma omp critical (MALLOC)
   { if( nc == 0 ){ free(qcon); qcon = NULL; } }  /* no actual constraints */

   /* regress-ifization */

   qfit = THD_fitter( npt , far , nfree , qref , meth , qcon ) ;

   /* copy results into output (ii re-walks the compressed fit vector) */

   if( qfit != NULL ){
     for( ii=jj=0 ; jj < nref ; jj++ ){
       if( fr[jj] ) ppar[jj] = qfit->ar[ii++] ;
     }
   }

   /* vamoose the ranch */

#pragma omp critical (MALLOC)
   { free(qref) ;
     if( qcon != NULL ) free(qcon) ;
     if( qfit != NULL ) KILL_floatvec(qfit) ; }

   EXRETURN ;
}
327
328 /*----------------------------------------------------------------------------*/
329 /* Check inputs for stupidities */
330
check_inputs(int npt,float * far,int nref,float * ref[])331 static int check_inputs( int npt , float *far ,
332 int nref , float *ref[] )
333 {
334 int jj ;
335 if( npt <= 1 || far == NULL || nref <= 0 || ref == NULL ) return 1 ;
336 for( jj=0 ; jj < nref ; jj++ ) if( ref[jj] == NULL ) return 2 ;
337 return 0 ;
338 }
339
340 /*----------------------------------------------------------------------------*/
341
THD_lasso(int meth,int npt,float * far,int nref,float * ref[],float * lam,float * ccon)342 floatvec * THD_lasso( int meth ,
343 int npt , float *far ,
344 int nref , float *ref[] ,
345 float *lam , float *ccon )
346 {
347 switch( meth ){
348
349 default:
350 case 2:
351 case -2: return THD_lasso_L2fit ( npt,far , nref,ref , lam,ccon ) ;
352
353 case 1:
354 case -1: return THD_sqrtlasso_L2fit( npt,far , nref,ref , lam,ccon ) ;
355
356 }
357 return NULL ; /* unreachable */
358 }
359
360 /*----------------------------------------------------------------------------*/
361 /* LASSO (L2 fit + L1 penalty) fit a vector to a set of reference vectors.
362 Input parameters are
363 * npt = Length of input vectors
364 * nref = Number of references
365 * far = Vector to be fitted
366 * ref[k] = k-th reference vector, for k=0..nref-1
367 * lam[k] = L1 penalty factor for the k-th reference (non-negative)
368 If lam == NULL, or all values are zero, then the value
369 set by THD_lasso_fixlam() will be used instead.
370 * ccon[k] = If ccon != NULL, ccon[k] is a sign constraint on the k-th
371 output coefficient: ccon[k] = 0 == no constraint
372 > 0 == coef #k must be >= 0
373 < 0 == coef #k must be <= 0
374 Unlike standard linear fitting, nref can be more than npt, since the
375 L1 penalty can force some coefficients to be exactly zero. However,
376 at most npt-1 values of lam[] can be zero (or the problem is unsolvable).
377
378 The return vector contains the nref coefficients. If NULL is returned,
379 then something bad bad bad transpired and you should hang your head.
380
381 TT Wu and K Lange.
382 Coordinate descent algorithms for LASSO penalized regression.
383 Annals of Applied Statistics, 2: 224-244 (2008).
384 http://arxiv.org/abs/0803.3876
385 *//*--------------------------------------------------------------------------*/
386
/* Cyclic coordinate descent for the L2+L1 (LASSO) problem described above. */

floatvec * THD_lasso_L2fit( int npt , float *far ,
                            int nref , float *ref[] ,
                            float *lam , float *ccon )
{
   int ii,jj, nfree,nite,nimax,ndel , do_slam=0 ;
   float *mylam, *ppar, *resd, *rsq, *rj, pj,dg,dsum,dsumx,ll ;
   floatvec *qfit ; byte *fr ;
   float *med , mval , pv , mv ; /* for centro blocks [06 Aug 2021] */

   ENTRY("THD_lasso_L2fit") ;

   /*--- check inputs for stupidities ---*/

   jj = check_inputs( npt , far , nref , ref ) ;
   if( jj ){
     static int ncall=0 ;
     if( ncall < 2 ){ ERROR_message("LASSO: bad data and/or model"); ncall++; }
     RETURN(NULL) ;
   }

   /*--- construct a local copy of lam[], and edit it softly ---*/

   mylam = edit_lamvec( npt , nref , lam ) ;

   /*--- space for parameter iterates, etc (initialized to zero) ---*/

#pragma omp critical (MALLOC)
   { MAKE_floatvec(qfit,nref) ; ppar = qfit->ar ; /* parameters = output */
     resd = (float *)calloc(sizeof(float),npt ) ; /* residuals */
     rsq  = (float *)calloc(sizeof(float),nref) ; /* sums of squares */
     fr   = (byte  *)calloc(sizeof(byte) ,nref) ; /* free list */
     med  = (float *)calloc(sizeof(float),nref) ; /* block centros */
   }

   /*--- Save 1/(sum of squares) of each ref column;
         optionally scale each nonzero lambda by the sigma estimate ---*/

   dsum = (do_sigest) ? estimate_sigma(npt,far) : 1.0f ;

   nfree = 0 ;  /* number of unconstrained parameters */
   for( jj=0 ; jj < nref ; jj++ ){
     rj = ref[jj] ;
     for( pj=ii=0 ; ii < npt ; ii++ ) pj += rj[ii]*rj[ii] ;
     if( pj > 0.0f ){
       rsq[jj] = 1.0f / pj ;
       if( mylam[jj] == 0.0f ){           /* unconstrained parameter */
         fr[jj] = 1 ; nfree++ ;
       } else {
         mylam[jj] *= dsum * sqrtf(pj) ;  /* scale for size of regressors */
       }
     }
   }

   /*--- if any parameters are free (no L1 penalty),
         initialize them by un-penalized least squares
         (implicitly assuming all other parameters are zero) ---*/

   if( AO_VALUE(vpar) == NULL || AO_VALUE(vpar)->nar < nref ){
     /* compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ; */
   } else {   /* start from the user-supplied parameter estimates instead */
     for( ii=0 ; ii < nref ; ii++ ) ppar[ii] = AO_VALUE(vpar)->ar[ii] ;
   }

   /*--- initialize residuals ---*/

   for( ii=0 ; ii < npt ; ii++ ) resd[ii] = far[ii] ; /* data */

   for( jj=0 ; jj < nref ; jj++ ){  /* subtract off fit of each column */
     pj = ppar[jj] ; rj = ref[jj] ; /* with a nonzero parameter value */
     if( pj != 0.0f ){
       for( ii=0 ; ii < npt ; ii++ ) resd[ii] -= rj[ii]*pj ;
     }
   }

   /*--- if have a lot of references (relative to number of data points),
         then increase lam[] for the first iterations to speed convergence ---*/

   do_slam = 3 * (nref > npt/4) ;                      /* first 3 */
   if( do_slam ){                                      /*   |     */
     for( jj=0 ; jj < nref ; jj++ ) mylam[jj] *= 8.0f ; /* 8 = 2^3 */
   }

   /*---- outer iteration loop (until we are happy or worn out) ----*/

#undef  CON     /* CON(j) is true if the constraint on ppar[j] is violated */
#define CON(j)  (ccon != NULL && ppar[j]*ccon[j] < 0.0f)

#undef  CONP    /* CONP(j) is true if ppar[j] is supposed to be >= 0 */
#define CONP(j) (ccon != NULL && ccon[j] > 0.0f)

#undef  CONN    /* CONN(j) is true if ppar[j] is supposed to be <= 0 */
#define CONN(j) (ccon != NULL && ccon[j] < 0.0f)

   /* enable the info message only for the first couple calls and
      every 10000-th call thereafter (thread-local call counter)   */

   { AO_DEFINE_SCALAR(int,ncall) ;
     lasso_verb = ( AO_VALUE(ncall) < 2 || AO_VALUE(ncall)%10000 == 1 ) ;
     AO_VALUE(ncall)++ ;
   }

   /* iteration limit scales with problem dimensions */

   ii = MAX(nref,npt) ; jj = MIN(nref,npt) ; nimax = 17 + 5*ii + 31*jj ;
   dsumx = dsum = 1.0f ;
#if 0
   if( lasso_verb && cenblok_num > 0 ) INFO_message("LASSO => start iterations") ;
#endif
   for( nite=0 ; nite < nimax && dsum+dsumx > deps ; nite++ ){

     /*--- load block centros [06 Aug 2021] ---*/

     if( nite > 3 ) load_block_centros( nref , ppar , med ) ;

     /*-- cyclic inner loop over parameters --*/

     dsumx = dsum ;   /* save last value of dsum */
#if 1
     for( dsum=ndel=jj=0 ; jj < nref ; jj++ ){ /* dsum = sum of param deltas */
#else
     for( dsum=ndel=0,jj=nref-1 ; jj >= 0 ; jj-- ){ /* dsum = sum of param deltas */
#endif

       if( rsq[jj] == 0.0f ) continue ; /* all zero column!? */
       rj = ref[jj] ;                   /* j-th reference column */
       pj = ppar[jj] ;                  /* current value of j-th parameter */
       ll = mylam[jj] ;                 /* lambda for this param */
#if 1
       mv = med[jj] ;                   /* shrinkage target (e.g., 0) */
#else
       mv = 0.0f ;
#endif
       pv = pj - mv ;                   /* param diff from shrinkage target */

       /* compute dg = -gradient of un-penalized function wrt ppar[jj] */
       /*            = direction we want to step in */

       for( dg=ii=0 ; ii < npt ; ii++ ) dg += resd[ii] * rj[ii] ;

       /*- modify parameter down the gradient -*/

       if( ll == 0.0f ){   /* un-penalized parameter */

         ppar[jj] += dg*rsq[jj] ; if( CON(jj) ) ppar[jj] = 0.0f ;

       } else {            /* penalized parameter */

         /* Extra -gradient is -lambda for pj > 0, and is +lambda for pj < 0. */
         /* Merge this with dg, change ppar[jj], then see if we stepped thru  */
         /* zero (or hit a constraint) -- if so, stop ppar[jj] at zero.       */

         if( pv > 0.0f || (pv == 0.0f && dg > ll) ){          /* on the + side */
           dg -= ll ; ppar[jj] += dg*rsq[jj] ;                /* shrink - way */
           if( ppar[jj] < mv ) ppar[jj] = mv ;                /* went too far down? */
           if( CON(jj) ) ppar[jj] = 0.0f ;                    /* violate constraint? */
         } else if( pv < 0.0f || (pv == 0.0f && dg < -ll) ){  /* on the - side */
           dg += ll ; ppar[jj] += dg*rsq[jj] ;                /* shrink + way */
           if( ppar[jj] > mv ) ppar[jj] = mv ;                /* too far up? */
           if( CON(jj) ) ppar[jj] = 0.0f ;                    /* constraint? */
         }

       }

       dg = ppar[jj] - pj ; /* change in parameter */
       if( dg != 0.0f ){    /* update convergence test and residuals */
         pj = fabsf(ppar[jj]) + fabsf(pj) ;
         dsum += fabsf(dg) / MAX(pj,0.001f) ; ndel++ ;
         for( ii=0 ; ii < npt ; ii++ ) resd[ii] -= rj[ii] * dg ;
       }

     } /*-- end of inner loop over parameters --*/

     /**** test for convergence somehow ***/

     if( ndel > 0 ) dsum *= (2.0f/ndel) ;  /* average relative change */

     if( do_slam ){   /* shrink lam[] back, if it was augmented */
       do_slam-- ; dsum = 1.0f ;           /* dsum=1 forces another pass */
       for( jj=0 ; jj < nref ; jj++ ) mylam[jj] *= 0.5f ;
     }

   } /*---- end of outer iteration loop ----*/

#if 1
   { if( lasso_verb ){
       for( nfree=jj=0 ; jj < nref ; jj++ ) nfree += (ppar[jj] != 0.0f) ;
       INFO_message("\nLASSO: nite=%d dsum=%g dsumx=%g nfree=%d/%d",nite,dsum,dsumx,nfree,nref) ;
     }
   }
#endif

   /*--- if 'post' computation is ordered, re-do the
         regression without constraints, but using only
         the references with non-zero weights from above ---*/

   if( do_post ){
     nfree = 0 ;
     for( jj=0 ; jj < nref ; jj++ ){  /* count and mark params to use */
       fr[jj] = (ppar[jj] != 0.0f) ; nfree += fr[jj] ;
     }
     compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ;
   }

   /*--- Loading up the truck and heading to Beverlee ---*/

#pragma omp critical (MALLOC)
   { free(fr) ; free(rsq) ; free(resd) ; free(mylam) ; free(med) ; }

   RETURN(qfit) ;
}
589
590 /*----------------------------------------------------------------------------*/
591 /**------ minimizers of f(x) = sqrt(a*x*x+b*x+c) + d*x
592 valid for d*d < a and for 4*a*c-b*b > 0 [positive quadratic] -----**/
593
/* NOTE(review): XPLU/XMIN are not referenced in this file -- the sqrt-LASSO
   solver below computes the minimizer inline (via qal[]) -- they document
   the closed-form solution of f'(x)=0 for f(x) = sqrt(a*x*x+b*x+c) + d*x. */

#undef  XPLU /* for a > d > 0 */
#define XPLU(a,b,c,d)                                                  \
   ( -( (b) * ((a)-(d)*(d))                                            \
       + sqrtf( (d)*(d) * ((a)-(d)*(d)) * (4.0f*(a)*(c)-(b)*(b)) )     \
      ) / ( 2.0f * (a) * ((a)-(d)*(d)) )                               \
   )

#undef  XMIN /* for -a < d < 0 */
#define XMIN(a,b,c,d)                                                  \
   ( -( (b) * ((a)-(d)*(d))                                            \
       - sqrtf( (d)*(d) * ((a)-(d)*(d)) * (4.0f*(a)*(c)-(b)*(b)) )     \
      ) / ( 2.0f * (a) * ((a)-(d)*(d)) )                               \
   )
607
608 /*----------------------------------------------------------------------------*/
609 /* Square Root L2 LASSO, similar to L2 LASSO function directly above; see
610 A Belloni, V Chernozhukov, and L Wang.
611 Square-root LASSO: Pivotal recovery of sparse signals via conic programming.
612 http://arxiv.org/abs/1009.5689
613 This function uses a coordinate descent method similar to the pure
614 LASSO function earlier. The key is that for 1 coordinate at a time, the
615 problem is reduced to minimizing a function of the form
616 f(x) = sqrt(a*x*x+b*x+c) + d*abs(x)
617 and the minimizer can be written in closed form as the solution to a
618 quadratic equation -- see XPLU and XMIN macros above.
619 *//*--------------------------------------------------------------------------*/
620
floatvec * THD_sqrtlasso_L2fit( int npt , float *far ,
                                int nref , float *ref[] ,
                                float *lam , float *ccon )
{
   int ii,jj, nfree,nite,nimax,ndel ;
   float *mylam, *ppar, *resd, *rsq, *rj, pj,dg,dsum,dsumx ;
   float rqsum,aa,bb,cc,ll,all , npinv , *ain,*qal ;
   floatvec *qfit ; byte *fr ;

   ENTRY("THD_sqrtlasso_L2fit") ;

   /*--- check inputs for stupidities ---*/

   jj = check_inputs( npt , far , nref , ref ) ;
   if( jj ){
     static int ncall=0 ;
     if( ncall < 2 ){ ERROR_message("SQRT LASSO: bad data and/or model"); ncall++; }
     RETURN(NULL) ;
   }

   /*--- construct a local copy of lam[], and edit it softly ---*/

   mylam = edit_lamvec( npt , nref , lam ) ;

   /*--- space for parameter iterates, etc (initialized to zero) ---*/

#pragma omp critical (MALLOC)
   { MAKE_floatvec(qfit,nref) ; ppar = qfit->ar ; /* parameters = output */
     resd = (float *)calloc(sizeof(float),npt ) ; /* residuals */
     rsq  = (float *)calloc(sizeof(float),nref) ; /* sums of squares */
     fr   = (byte  *)calloc(sizeof(byte) ,nref) ; /* free list */
     ain  = (float *)calloc(sizeof(float),nref) ; /* -0.5 / rsq */
     qal  = (float *)calloc(sizeof(float),nref) ; /* step adjustment */
   }

   /*--- Save sum of squares of each ref column ---*/

   npinv = 1.0f / (float)npt ;
   nfree = 0 ;  /* number of unconstrained parameters */
   for( jj=0 ; jj < nref ; jj++ ){
     rj = ref[jj] ;
     for( pj=ii=0 ; ii < npt ; ii++ ) pj += rj[ii]*rj[ii] ; /* sum of squares */
     rsq[jj] = pj * npinv ;                                 /* average per data pt */
     if( pj > 0.0f ){
       if( mylam[jj] == 0.0f ){ fr[jj] = 1 ; nfree++ ; }    /* unconstrained */
       ain[jj] = -0.5f / rsq[jj] ;
     }
   }

   /* scale and edit mylam to make sure it isn't too big for sqrt(quadratic) */

#undef  AFAC
#define AFAC 0.222f

   cc = sqrtf(npinv)*AFAC ;
   for( jj=0 ; jj < nref ; jj++ ){
     ll = mylam[jj] ;
     if( ll > 0.0f ){
       aa = sqrtf(rsq[jj]); ll *= aa*cc; if( ll > AFAC*aa ) ll = AFAC*aa;
       mylam[jj] = ll ;
       /* qal[jj] pre-computes the discriminant factor used each iteration */
       qal[jj] = ll*ll * 4.0f*rsq[jj]/(rsq[jj]-ll*ll) ;
     }
   }

   /*--- if any parameters are free (no L1 penalty),
         initialize them by un-penalized least squares
         (implicitly assuming all other parameters are zero) ---*/

   if( AO_VALUE(vpar) == NULL || AO_VALUE(vpar)->nar < nref ){
     /* compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ; */
   } else {   /* start from the user-supplied parameter estimates instead */
     for( ii=0 ; ii < nref ; ii++ ) ppar[ii] = AO_VALUE(vpar)->ar[ii] ;
   }

   /*--- initialize residuals ---*/

   for( ii=0 ; ii < npt ; ii++ ) resd[ii] = far[ii] ; /* data */

   for( jj=0 ; jj < nref ; jj++ ){  /* subtract off fit of each column */
     pj = ppar[jj] ; rj = ref[jj] ; /* with a nonzero parameter value */
     if( pj != 0.0f ){
       for( ii=0 ; ii < npt ; ii++ ) resd[ii] -= rj[ii]*pj ;
     }
   }
   for( rqsum=ii=0 ; ii < npt ; ii++ ) rqsum += resd[ii]*resd[ii] ;
   rqsum *= npinv ;   /* rqsum = mean squared residual */

   /*---- outer iteration loop (until we are happy or worn out) ----*/

   ii = MAX(nref,npt) ; jj = MIN(nref,npt) ; nimax = 17 + 5*ii + 31*jj ;
   dsumx = dsum = 1.0f ;
   for( nite=0 ; nite < nimax && dsum+dsumx > deps ; nite++ ){

     /*-- cyclic inner loop over parameters --*/

     dsumx = dsum ; /* save last value of dsum */

     for( dsum=ndel=jj=0 ; jj < nref ; jj++ ){ /* dsum = sum of param deltas */

       if( rsq[jj] == 0.0f ) continue ; /* all zero column!? */
       rj = ref[jj] ;                   /* j-th reference column */
       pj = ppar[jj] ;                  /* current value of j-th parameter */

       for( dg=ii=0 ; ii < npt ; ii++ ) dg += resd[ii] * rj[ii] ;
       dg *= npinv ;  /* dg = average of resd*ref over the npt points */

       /* want to minimize (wrt x) function sqrt(aa*x*x+bb*x+cc) + ll*abs(x) */

       aa = rsq[jj] ;
       bb = -2.0f * (dg+aa*pj) ;
       cc = rqsum + (2.0f*dg + aa*pj)*pj ;
       ll = mylam[jj] ;

       /*- modify parameter -*/

       if( ll == 0.0f ){ /* un-penalized parameter */

         ppar[jj] = bb * ain[jj] ; if( CON(jj) ) ppar[jj] = 0.0f ;

       } else {
#if 0
         float qq = ll * sqrtf(4.0f*aa*cc/(aa-ll*ll)) ;
         if( pj > 0.0f || (pj == 0.0f && bb+qq < 0.0f) ){
           ppar[jj] = (bb+qq) * ain[jj] ;  /* solution on positive side */
           if( ppar[jj] < 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
         } else if( pj < 0.0f || (pj == 0.0f && bb-qq > 0.0f) ){
           ppar[jj] = (bb-qq) * ain[jj] ;  /* solution on negative side */
           if( ppar[jj] > 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
         }
#else
         float qq = qal[jj] * cc ; /* positive by construction */
         if( pj > 0.0f ){
           ppar[jj] = (bb+sqrtf(qq)) * ain[jj] ; /* positive side */
           if( ppar[jj] < 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
         } else if( pj < 0.0f ){
           ppar[jj] = (bb-sqrtf(qq)) * ain[jj] ; /* negative side */
           if( ppar[jj] > 0.0f || CON(jj) ) ppar[jj] = 0.0f ;
         } else {            /* initial pj == 0 */
           if( bb*bb > qq ){ /* gradient step overpowers L1 penalty */
             if( bb < 0.0f && !CONN(jj) ){           /* [note ain < 0] */
               ppar[jj] = (bb+sqrtf(qq)) * ain[jj] ; /* step to + side */
             } else if( !CONP(jj) ){
               ppar[jj] = (bb-sqrtf(qq)) * ain[jj] ; /* step to - side */
             }
           }
         }
#endif
       }

       dg = ppar[jj] - pj ; /* change in parameter */
       if( dg != 0.0f ){    /* update convergence test and residuals */
         pj = fabsf(ppar[jj]) + fabsf(pj) ;
         dsum += fabsf(dg) / MAX(pj,0.001f) ; ndel++ ;
         for( rqsum=ii=0 ; ii < npt ; ii++ ){
           resd[ii] -= rj[ii] * dg ; rqsum += resd[ii]*resd[ii] ;
         }
         rqsum *= npinv ;
       }

     } /*-- end of inner loop over parameters --*/

     /**** test for convergence somehow ***/

     if( ndel > 0 ) dsum *= (2.0f/ndel) ;  /* average relative change */

   } /*---- end of outer iteration loop ----*/

#if 1
   { AO_DEFINE_SCALAR(int,ncall) ;
     if( AO_VALUE(ncall) < 2 || AO_VALUE(ncall)%10000 == 1 ){
       for( nfree=jj=0 ; jj < nref ; jj++ ) nfree += (ppar[jj] != 0.0f) ;
       INFO_message("SQRTLASSO %d: nite=%d dsum=%g dsumx=%g nfree=%d/%d",AO_VALUE(ncall),nite,dsum,dsumx,nfree,nref) ;
     }
     AO_VALUE(ncall)++ ;
   }
#endif

   /*--- if 'post' computation is ordered, re-do the
         regression without constraints, but using only
         the references with non-zero weights from above ---*/

   if( do_post ){
     nfree = 0 ;
     for( jj=0 ; jj < nref ; jj++ ){  /* count and mark params to use */
       fr[jj] = (ppar[jj] != 0.0f) ; nfree += fr[jj] ;
     }
     compute_free_param( npt,far,nref,ref,2,ccon , nfree,fr , ppar ) ;
   }

   /*--- Loading up the truck and heading to Beverlee ---*/

#pragma omp critical (MALLOC)
   { free(qal) ; free(ain) ; free(fr) ; free(rsq) ; free(resd) ; free(mylam) ; }

   RETURN(qfit) ;
}
817