1 #include "mrilib.h"
2
/* 2*Pi to 10 significant digits; used to map sample index to phase angle */
#undef  TWOPI
#define TWOPI 6.283185307
5
/* Coefficients of a 1D warp of the form
     y(x) = a*x + sum_{k=1..nwt} wt[k-1] * sin(2*k*Pi*x)
   (this is the form printed by main() and sampled by wtarray_inverse()). */
typedef struct {
   int    nwt ;  /* number of sine weights stored in wt[] */
   float  a ;    /* coefficient of the linear term */
   float *wt ;   /* sine weights, nwt of them (NULL if nwt is 0) */
} wtarray ;
9
/* Allocate a wtarray into pointer variable nam, with linear coefficient aa
   and nn zero-filled sine weights (no weight array if nn <= 0).
   - nam is left NULL if the struct itself cannot be allocated.
   - if the weight array cannot be allocated, nwt is reset to 0 so that
     nwt > 0 always implies wt != NULL.
   - aa and nn are each evaluated exactly once; nam must be an lvalue. */
#undef  INIT_wtarray
#define INIT_wtarray(nam,aa,nn)                                              \
 do{ (nam) = (wtarray *)malloc(sizeof(wtarray)) ;                            \
     if( (nam) != NULL ){                                                    \
       (nam)->nwt = (nn) ; (nam)->a = (aa) ; (nam)->wt = (float *)NULL ;     \
       if( (nam)->nwt > 0 ){                                                 \
         (nam)->wt = (float *)calloc((size_t)(nam)->nwt,sizeof(float)) ;     \
         if( (nam)->wt == NULL ) (nam)->nwt = 0 ;                            \
       }                                                                     \
     }                                                                       \
 } while(0)
17
/* Release p if it is non-NULL, then reset it to NULL (p must be an lvalue) */
#undef  FREEIF
#define FREEIF(p)               \
 do{ if( (p) != NULL ){         \
       free(p) ;                \
       (p) = NULL ;             \
     }                          \
 } while(0)
20
21 static MRI_IMAGE * mri_psinv( MRI_IMAGE *imc , float *wt ) ;
22
23 /*---------------------------------------------------------------------------*/
24
/*! Least-squares fit of the inverse of a sine-series warp.

    The forward warp is y(x) = a*x + sum_k wf->wt[k-1]*sin(2*k*Pi*x), sampled
    at nnx equally spaced points with x running over 0..1.  The inverse is
    modeled as x(y) = y/a + sum_k wi->wt[k-1]*sin(2*k*Pi*y), and the weights
    are found by applying the pseudo-inverse of the sampled sine basis.

    \param nnx  number of sample points (must be at least 2)
    \param wf   forward warp; wf->a must be nonzero
    \param nwi  number of inverse weights wanted; if < 1, wf->nwt is used
    \param wi   output; any existing wi->wt is freed and replaced
    \return RMS residual of the fitted inverse over the sample points;
            0.0f when nothing could be computed.

    On any failure after wi has been checked, wi->nwt is 0 and wi->wt is NULL,
    so the caller can detect it (the return value alone is ambiguous).
*/
float wtarray_inverse( int nnx , wtarray *wf , int nwi , wtarray *wi )
{
   MRI_IMAGE *imc , *imp ;
   float *car , *par , *wtf , *wti , *rhs ;
   int nx=nnx , nwf , ii,jj ;
   float dx , xx , ss , yy , aa,ainv ;
   double esum ;

   /** take care of pathological cases **/

   if( wi == NULL || nx < 2 ) return 0.0f ;
   if( wf == NULL || wf->a == 0.0f ){
     ERROR_message("wf is bad") ;
     wi->a = 1.0f ; wi->nwt = 0 ; FREEIF(wi->wt) ; return 0.0f ;
   }
   if( wf->nwt <= 0 ){                       /* purely linear forward warp */
     ERROR_message("wf->nwt is %d",wf->nwt) ;
     wi->a   = 1.0f/wf->a ;                  /* wf->a != 0 was checked above */
     wi->nwt = 0 ; FREEIF(wi->wt) ; return 0.0f ;
   }
   if( wf->wt == NULL ){                     /* weights promised but absent */
     ERROR_message("wf->wt is NULL") ;
     wi->a   = 1.0f/wf->a ;
     wi->nwt = 0 ; FREEIF(wi->wt) ; return 0.0f ;
   }
   if( nwi < 1 ) nwi = wf->nwt ;

   /** setup: wi is only marked as filled once the solve succeeds **/

   FREEIF(wi->wt) ; wi->nwt = 0 ;
   aa  = wf->a ; wi->a = ainv = 1.0f/aa ;
   dx  = TWOPI / (nx-1) ;                    /* phase step: samples span 0..2*Pi */
   nwf = wf->nwt ;
   wtf = wf->wt ;

   wti = (float *)malloc(sizeof(float)*nwi) ;
   rhs = (float *)malloc(sizeof(float)*nx ) ;
   imc = mri_new( nx , nwi , MRI_float ) ;
   if( wti == NULL || rhs == NULL || imc == NULL ){
     ERROR_message("wtarray_inverse: can't allocate workspace") ;
     free(wti) ; free(rhs) ; if( imc != NULL ) mri_free(imc) ;
     return 0.0f ;
   }
   car = MRI_FLOAT_PTR(imc) ;

   /* basis functions are sin(2*k*Pi*x) for k=1..nw, x=0..1 */

   for( ii=0 ; ii < nx ; ii++ ){
     xx = dx * ii ; ss = 0.0f ;
     for( jj=0 ; jj < nwf ; jj++ ) ss += wtf[jj] * sinf((jj+1)*xx) ;
     rhs[ii] = -ainv * ss ;                  /* what the inverse sine sum must equal */
     yy = aa*xx + TWOPI*ss ;                 /* 2*Pi*y(x) = phase of the warped point */
     for( jj=0 ; jj < nwi ; jj++ ) car[ii+jj*nx] = sinf((jj+1)*yy) ;
   }

   /* pseudo-inverse of the nx X nwi basis matrix is nwi X nx */

   imp = mri_psinv( imc , NULL ) ; mri_free(imc) ;
   if( imp == NULL ){                        /* SVD/Choleski failure inside mri_psinv */
     ERROR_message("wtarray_inverse: pseudo-inverse failed") ;
     free(wti) ; free(rhs) ; return 0.0f ;
   }
   par = MRI_FLOAT_PTR(imp) ;

   for( jj=0 ; jj < nwi ; jj++ ){
     ss = 0.0f ;
     for( ii=0 ; ii < nx ; ii++ ) ss += par[jj+ii*nwi] * rhs[ii] ;
     wti[jj] = ss ;
   }
   mri_free(imp) ; free(rhs) ;

   wi->wt = wti ; wi->nwt = nwi ;            /* publish the result */

   /* RMS residual of x(y(x)) - x over the sample points */

   esum = 0.0 ;
   for( ii=0 ; ii < nx ; ii++ ){
     xx = dx * ii ; ss = 0.0f ;
     for( jj=0 ; jj < nwf ; jj++ ) ss += wtf[jj] * sinf((jj+1)*xx) ;
     yy = aa*xx + TWOPI*ss ;
     ss = ainv*ss ;
     for( jj=0 ; jj < nwi ; jj++ ) ss += wti[jj] * sinf((jj+1)*yy) ;
     esum += ss*ss ;
   }
   return (float)sqrt(esum/nx) ;
}
88
89 /*-----------------------------------------------------------------------*/
90 /*! Compute the pseudo-inverse of a matrix stored in a 2D float image.
91 If the input is mXn, the output is nXm. wt[] is an optional array
92 of positive weights, m of them. The result can be used to solve
93 the weighted least squares problem
94 [imc] [b] = [v]
95 where [b] is an n-vector and [v] is an m-vector, where m > n.
96 -------------------------------------------------------------------------*/
97
mri_psinv(MRI_IMAGE * imc,float * wt)98 static MRI_IMAGE * mri_psinv( MRI_IMAGE *imc , float *wt )
99 {
100 float *rmat=MRI_FLOAT_PTR(imc) ;
101 int m=imc->nx , n=imc->ny , ii,jj,kk ;
102 double *amat , *umat , *vmat , *sval , *xfac , smax,del,ww ;
103 MRI_IMAGE *imp ; float *pmat ;
104 register double sum ;
105 int do_svd=0 ;
106
107 amat = (double *)calloc( sizeof(double),m*n ) ; /* input matrix */
108 xfac = (double *)calloc( sizeof(double),n ) ; /* column norms of [a] */
109
110 #define R(i,j) rmat[(i)+(j)*m] /* i=0..m-1 , j=0..n-1 */
111 #define A(i,j) amat[(i)+(j)*m] /* i=0..m-1 , j=0..n-1 */
112 #define P(i,j) pmat[(i)+(j)*n] /* i=0..n-1 , j=0..m-1 */
113
114 /* copy input matrix into amat */
115
116 for( ii=0 ; ii < m ; ii++ )
117 for( jj=0 ; jj < n ; jj++ ) A(ii,jj) = R(ii,jj) ;
118
119 /* weight rows? */
120
121 if( wt != NULL ){
122 for( ii=0 ; ii < m ; ii++ ){
123 ww = wt[ii] ;
124 for( jj=0 ; jj < n ; jj++ ) A(ii,jj) *= ww ;
125 }
126 }
127
128 /* scale each column to have norm 1 */
129
130 for( jj=0 ; jj < n ; jj++ ){
131 sum = 0.0 ;
132 for( ii=0 ; ii < m ; ii++ ) sum += A(ii,jj)*A(ii,jj) ;
133 if( sum > 0.0 ) sum = 1.0/sqrt(sum) ; else do_svd = 1 ;
134 xfac[jj] = sum ;
135 for( ii=0 ; ii < m ; ii++ ) A(ii,jj) *= sum ;
136 }
137
138 /*** compute using Choleski or SVD ***/
139
140 if( do_svd || AFNI_yesenv("AFNI_WARPDRIVE_SVD") ){ /***--- SVD method ---***/
141
142 #define U(i,j) umat[(i)+(j)*m]
143 #define V(i,j) vmat[(i)+(j)*n]
144
145 umat = (double *)calloc( sizeof(double),m*n ); /* left singular vectors */
146 vmat = (double *)calloc( sizeof(double),n*n ); /* right singular vectors */
147 sval = (double *)calloc( sizeof(double),n ); /* singular values */
148
149 /* compute SVD of scaled matrix */
150
151 svd_double( m , n , amat , sval , umat , vmat ) ;
152
153 free((void *)amat) ; /* done with this */
154
155 /* find largest singular value */
156
157 smax = sval[0] ;
158 for( ii=1 ; ii < n ; ii++ ) if( sval[ii] > smax ) smax = sval[ii] ;
159
160 if( smax <= 0.0 ){ /* this is bad */
161 fprintf(stderr,"** ERROR: SVD fails in mri_warp3D_align_setup!\n");
162 free((void *)xfac); free((void *)sval);
163 free((void *)vmat); free((void *)umat); return NULL;
164 }
165
166 for( ii=0 ; ii < n ; ii++ )
167 if( sval[ii] < 0.0 ) sval[ii] = 0.0 ; /* should not happen */
168
169 #define PSINV_EPS 1.e-8
170
171 /* "reciprocals" of singular values: 1/s is actually s/(s^2+del) */
172
173 del = PSINV_EPS * smax*smax ;
174 for( ii=0 ; ii < n ; ii++ )
175 sval[ii] = sval[ii] / ( sval[ii]*sval[ii] + del ) ;
176
177 /* create pseudo-inverse */
178
179 imp = mri_new( n , m , MRI_float ) ; /* recall that m > n */
180 pmat = MRI_FLOAT_PTR(imp) ;
181
182 for( ii=0 ; ii < n ; ii++ ){
183 for( jj=0 ; jj < m ; jj++ ){
184 sum = 0.0 ;
185 for( kk=0 ; kk < n ; kk++ ) sum += sval[kk] * V(ii,kk) * U(jj,kk) ;
186 P(ii,jj) = (float)sum ;
187 }
188 }
189 free((void *)sval); free((void *)vmat); free((void *)umat);
190
191 } else { /***----- Choleski method -----***/
192
193 vmat = (double *)calloc( sizeof(double),n*n ); /* normal matrix */
194
195 for( ii=0 ; ii < n ; ii++ ){
196 for( jj=0 ; jj <= ii ; jj++ ){
197 sum = 0.0 ;
198 for( kk=0 ; kk < m ; kk++ ) sum += A(kk,ii) * A(kk,jj) ;
199 V(ii,jj) = sum ;
200 }
201 V(ii,ii) += PSINV_EPS ; /* note V(ii,ii)==1 before this */
202 }
203
204 #if 0
205 fprintf(stderr,"NORMAL MATRIX::\n") ;
206 for( ii=0 ; ii < n ; ii++ ){
207 fprintf(stderr,"%2d:",ii) ;
208 for( jj=0 ; jj <= ii ; jj++ ) fprintf(stderr," %7.4f",V(ii,jj)) ;
209 fprintf(stderr,"\n") ;
210 }
211 #endif
212 #if 1
213 { double rr,rmax=0.0 ;
214 for( ii=0 ; ii < n ; ii++ ){
215 rr = 0.0 ;
216 for( jj=0 ; jj < n ; jj++ ){
217 if( jj < ii ) rr += fabs(V(ii,jj)) ;
218 else if( jj > ii ) rr += fabs(V(jj,ii)) ;
219 }
220 rr = rr / V(ii,ii) ; if( rr > rmax ) rmax = rr ;
221 }
222 fprintf(stderr,"MAX row ratio = %g\n",rmax) ;
223 }
224 #endif
225
226 /* Choleski factor */
227
228 for( ii=0 ; ii < n ; ii++ ){
229 for( jj=0 ; jj < ii ; jj++ ){
230 sum = V(ii,jj) ;
231 for( kk=0 ; kk < jj ; kk++ ) sum -= V(ii,kk) * V(jj,kk) ;
232 V(ii,jj) = sum / V(jj,jj) ;
233 }
234 sum = V(ii,ii) ;
235 for( kk=0 ; kk < ii ; kk++ ) sum -= V(ii,kk) * V(ii,kk) ;
236 if( sum <= 0.0 ){
237 fprintf(stderr,"** ERROR: Choleski fails in mri_warp3D_align_setup!\n");
238 free((void *)xfac); free((void *)amat); free((void *)vmat); return NULL ;
239 }
240 V(ii,ii) = sqrt(sum) ;
241 }
242
243 #if 0
244 fprintf(stderr,"CHOLESKI FACTOR::\n") ;
245 for( ii=0 ; ii < n ; ii++ ){
246 fprintf(stderr,"%2d:",ii) ;
247 for( jj=0 ; jj <= ii ; jj++ ) fprintf(stderr," %7.4f",V(ii,jj)) ;
248 fprintf(stderr,"\n") ;
249 }
250 #endif
251
252 /* create pseudo-inverse */
253
254 imp = mri_new( n , m , MRI_float ) ; /* recall that m > n */
255 pmat = MRI_FLOAT_PTR(imp) ;
256
257 sval = (double *)calloc( sizeof(double),n ) ; /* row #jj of A */
258
259 for( jj=0 ; jj < m ; jj++ ){
260 for( ii=0 ; ii < n ; ii++ ) sval[ii] = A(jj,ii) ; /* extract row */
261
262 for( ii=0 ; ii < n ; ii++ ){ /* forward solve */
263 sum = sval[ii] ;
264 for( kk=0 ; kk < ii ; kk++ ) sum -= V(ii,kk) * sval[kk] ;
265 sval[ii] = sum / V(ii,ii) ;
266 }
267 for( ii=n-1 ; ii >= 0 ; ii-- ){ /* backward solve */
268 sum = sval[ii] ;
269 for( kk=ii+1 ; kk < n ; kk++ ) sum -= V(kk,ii) * sval[kk] ;
270 sval[ii] = sum / V(ii,ii) ;
271 }
272
273 for( ii=0 ; ii < n ; ii++ ) P(ii,jj) = (float)sval[ii] ;
274 }
275 free((void *)amat); free((void *)vmat); free((void *)sval);
276 }
277
278 /* rescale rows from norming */
279
280 for( ii=0 ; ii < n ; ii++ ){
281 for( jj=0 ; jj < m ; jj++ ) P(ii,jj) *= xfac[ii] ;
282 }
283 free((void *)xfac);
284
285 /* rescale cols for weight? */
286
287 if( wt != NULL ){
288 for( ii=0 ; ii < m ; ii++ ){
289 ww = wt[ii] ;
290 for( jj=0 ; jj < n ; jj++ ) P(jj,ii) *= ww ;
291 }
292 }
293
294 return imp;
295 }
296
297 /*-----------------------------------------------------------------------*/
298
299 #define NW 10
300
main(int argc,char * argv[])301 int main( int argc , char *argv[] )
302 {
303 wtarray *wt_for , *wt_inv ;
304 int jj ;
305 int nwf=NW , nwi=0 ;
306 float err , aa=0.222 ;
307
308 if( argc > 1 ){
309 nwf = (int)strtod(argv[1],NULL) ;
310 if( nwf < 1 || nwf > 99 ) nwf = NW ;
311 }
312 if( argc > 2 ){
313 nwi = (int)strtod(argv[2],NULL) ;
314 if( nwi < 1 || nwi > 999 ) nwi = 0 ;
315 }
316 if( nwi < 1 ) nwi = nwf ;
317 if( argc > 3 ){
318 aa = (float)strtod(argv[3],NULL) ;
319 }
320
321 INIT_wtarray(wt_for,1.0f,nwf) ;
322 INIT_wtarray(wt_inv,1.0f,0 ) ;
323
324 for( jj=1 ; jj <= nwf ; jj++ )
325 wt_for->wt[jj-1] = aa /(jj*jj + 1.0f) ;
326
327 err = wtarray_inverse( 66*MAX(nwf,nwi) , wt_for , nwi , wt_inv ) ;
328
329 if( wt_inv->nwt == 0 || wt_inv->wt == NULL ) ERROR_exit("Bad wt_inv!") ;
330
331 printf("err = %.5g\n",err) ;
332
333 printf("z := x -> %f * x ",wt_inv->a) ;
334 for( jj=1 ; jj <= nwi ; jj++ ){
335 if( wt_inv->wt[jj-1] != 0.0f ){
336 if( wt_inv->wt[jj-1] > 0.0f ) printf(" +") ;
337 printf("%f * sin(%d*Pi*x)" , wt_inv->wt[jj-1] , 2*jj ) ;
338 }
339 }
340 printf(";\n") ;
341
342 printf("y := x -> %f * x ",wt_for->a) ;
343 for( jj=1 ; jj <= nwf ; jj++ ){
344 if( wt_for->wt[jj-1] != 0.0f ){
345 if( wt_for->wt[jj-1] > 0.0f ) printf(" +") ;
346 printf("%f * sin(%d*Pi*x)" , wt_for->wt[jj-1] , 2*jj ) ;
347 }
348 }
349 printf(";\n") ;
350
351 exit(0) ;
352 }
353