1 #ifndef _ltfat_dgtrealmp_private_h
2 #define _ltfat_dgtrealmp_private_h
3 #include "ltfat.h"
4 #include "ltfat/types.h"
5 #include "ltfat/macros.h"
6
LTFAT_NAME(dgtrealmp_parbuf)7 struct LTFAT_NAME(dgtrealmp_parbuf)
8 {
9 LTFAT_REAL** g;
10 ltfat_int* gl;
11 ltfat_int* a;
12 ltfat_int* M;
13 int* chanmask;
14 ltfat_int P;
15 ltfat_dgtmp_params* params;
16 LTFAT_NAME(dgtrealmp_iterstep_callback)* iterstepcallback;
17 void* iterstepcallbackdata;
18 // LTFAT_REAL chirprate;
19 // LTFAT_REAL shiftby;
20 };
21
22 struct ltfat_dgtmp_params
23 {
24 // ltfat_dgtrealmp_hint hint;
25 ltfat_dgtmp_alg alg;
26 double atprodreltoldb;
27 double atprodreltoladj;
28 long double errtoldb;
29 long double errtoladj;
30 double kernrelthr;
31 size_t maxit;
32 size_t maxatoms;
33 size_t iterstep;
34 int verbose;
35 int initwasrun;
36 int treelevels;
37 size_t cycles;
38 ltfat_phaseconvention ptype;
39 int do_pedantic;
40 };
41
42 typedef struct
43 {
44 ltfat_int height;
45 ltfat_int width;
46 } ksize;
47
48 typedef struct
49 {
50 ltfat_int hmid;
51 ltfat_int wmid;
52 } kanchor;
53
54 typedef struct
55 {
56 ltfat_int start;
57 ltfat_int end;
58 } krange;
59
60 typedef struct
61 {
62 ltfat_int m;
63 ltfat_int n;
64 ltfat_int w;
65 ltfat_int n2;
66 } kpoint;
67 #define PTOI(k) k.w][k.m + p->M2[k.w] * k.n
68 #define kpoint_init(m,n,w) LTFAT_STRUCTINIT(kpoint,m,n,w,n)
69 #define kpoint_init2(m,n,n2,w) LTFAT_STRUCTINIT(kpoint,m,n,w,n2)
70 #define kpoint_isequal(k1,k2) (k1.m == k2.m && k1.n == k2.n && k1.w == k2.w)
71
72
73 typedef struct
74 {
75 ksize size;
76 kanchor mid;
77 ltfat_int kNo;
78 ltfat_int kSkip;
79 LTFAT_COMPLEX** mods;
80 LTFAT_COMPLEX* kval;
81 krange* range;
82 krange* srange;
83 LTFAT_REAL absthr;
84 double Mrat;
85 double arat;
86 ltfat_int Mstep;
87 ltfat_int astep;
88 LTFAT_COMPLEX* atprods;
89 LTFAT_REAL* oneover1minatprodnorms;
90 ltfat_int atprodsNo;
91 int cloned;
92 ltfat_phaseconvention ptype;
93 } LTFAT_NAME(kerns);
94
95
96 typedef struct
97 {
98 LTFAT_COMPLEX** c;
99 LTFAT_REAL** maxcols;
100 ltfat_int** maxcolspos;
101 LTFAT_NAME(maxtree)** tmaxtree;
102 LTFAT_NAME(maxtree)*** fmaxtree;
103 unsigned int** suppind;
104 long double err;
105 long double fnorm2;
106 size_t currit;
107 size_t curratoms;
108 ltfat_int P;
109 ltfat_int* N;
110 LTFAT_COMPLEX** cvalModBuf;
111 // LocOMP related
112 LTFAT_COMPLEX* gramBuf;
113 LTFAT_COMPLEX* cvalBuf;
114 LTFAT_COMPLEX* cvalinvBuf;
115 kpoint* cvalBufPos;
116 LTFAT_NAME_COMPLEX(hermsystemsolver_plan)* hplan;
117 // CyclicMP related
118 kpoint* pBuf;
119 size_t pBufSize;
120 size_t pBufNo;
121 } LTFAT_NAME(dgtrealmpiter_state);
122
123
124 typedef struct
125 {
126 ltfat_int n;
127 ltfat_int w;
128 LTFAT_NAME(dgtrealmp_state)* state;
129 } LTFAT_NAME(dgtrealmp_state_closure);
130
LTFAT_NAME(dgtrealmp_state)131 struct LTFAT_NAME(dgtrealmp_state)
132 {
133 LTFAT_NAME(dgtrealmpiter_state)* iterstate;
134 LTFAT_NAME(kerns)** gramkerns; // PxP plans
135 LTFAT_NAME(dgtreal_plan)** dgtplans; // P plans
136 ltfat_int* a;
137 ltfat_int* M;
138 ltfat_int* M2;
139 ltfat_int* N;
140 int* chanmask;
141 ltfat_int P;
142 ltfat_int L;
143 ltfat_dgtmp_params* params;
144 LTFAT_COMPLEX** couttmp;
145 LTFAT_NAME(dgtrealmp_state_closure)** closures;
146 LTFAT_NAME(dgtrealmp_iterstep_callback)* callback;
147 void* userdata;
148 };
149
150 static inline LTFAT_REAL
ltfat_norm(LTFAT_COMPLEX c)151 ltfat_norm(LTFAT_COMPLEX c)
152 {
153 return ltfat_real(c) * ltfat_real(c) + ltfat_imag(c) * ltfat_imag(c);
154 }
155
156 static inline LTFAT_REAL
LTFAT_NAME(dgtrealmp_execute_projenergy)157 LTFAT_NAME(dgtrealmp_execute_projenergy)(
158 LTFAT_COMPLEX atinprod, LTFAT_COMPLEX cval)
159 {
160 LTFAT_REAL cr = ltfat_real(cval);
161 LTFAT_REAL ci = ltfat_imag(cval);
162 LTFAT_REAL cr2 = cr*cr;
163 LTFAT_REAL ci2 = ci*ci;
164 LTFAT_REAL two = (LTFAT_REAL) 2.0;
165 return two*(cr2 + ci2 + ltfat_real(atinprod)*(cr2 - ci2) - two*ltfat_imag(atinprod)*cr*ci);
166 }
167
168 static inline LTFAT_REAL
LTFAT_NAME(dgtrealmp_execute_dualprojenergy)169 LTFAT_NAME(dgtrealmp_execute_dualprojenergy)(
170 LTFAT_COMPLEX atinprod, LTFAT_REAL oneoveroneminatprodnorm, LTFAT_COMPLEX cval)
171 {
172 LTFAT_COMPLEX cvaldual = (cval - (atinprod) * conj(cval)) * oneoveroneminatprodnorm;
173 return LTFAT_NAME(dgtrealmp_execute_projenergy)( atinprod, cvaldual);
174 }
175
176 /* BEGIN_C_DECLS */
177 #ifdef __cplusplus
178 extern "C" {
179 #endif
180
181 int
182 LTFAT_NAME(dgtrealmpiter_init)(
183 ltfat_int a[], ltfat_int M[], ltfat_int P, ltfat_int L,
184 LTFAT_NAME(dgtrealmpiter_state)** state);
185
186 int
187 LTFAT_NAME(dgtrealmpiter_done)(LTFAT_NAME(dgtrealmpiter_state)** state);
188
189 int
190 LTFAT_NAME(dgtrealmp_kernel_cloneconj)(
191 LTFAT_NAME(kerns)* kin, LTFAT_NAME(kerns)** kout);
192
193 int
194 LTFAT_NAME(dgtrealmp_kernel_init)(
195 const LTFAT_REAL* g[], ltfat_int gl[], ltfat_int a[], ltfat_int M[],
196 ltfat_int L, LTFAT_REAL reltol, ltfat_phaseconvention ptype,
197 LTFAT_NAME(kerns)** pout);
198
199 int
200 LTFAT_NAME(dgtrealmp_kernel_done)(LTFAT_NAME(kerns)** k);
201
202 int
203 LTFAT_NAME(dgtrealmp_kernel_modfi)(
204 const LTFAT_COMPLEX* kfirst, ksize size, kanchor mid, ltfat_int n, ltfat_int a, ltfat_int M,
205 LTFAT_COMPLEX* kmod);
206
207 int
208 LTFAT_NAME(dgtrealmp_kernel_modti)(
209 const LTFAT_COMPLEX* kfirst, ksize size, kanchor mid, ltfat_int m, ltfat_int a, ltfat_int M,
210 LTFAT_COMPLEX* kmod);
211
212 int
213 LTFAT_NAME(dgtrealmp_kernel_modfiexp)(
214 ksize size, kanchor mid, ltfat_int n, ltfat_int a, ltfat_int M,
215 LTFAT_COMPLEX* kmod);
216
217 int
218 LTFAT_NAME(dgtrealmp_kernel_modtiexp)(
219 ksize size, kanchor mid, ltfat_int m, ltfat_int a, ltfat_int M,
220 LTFAT_COMPLEX* kmod);
221
222 int
223 LTFAT_NAME(dgtrealmp_kernel_findsmallsize)(
224 const LTFAT_COMPLEX kernlarge[], ltfat_int M, ltfat_int N,
225 LTFAT_REAL reltol, LTFAT_REAL* absthr, ksize* size, kanchor* anchor);
226
227 int
228 LTFAT_NAME(dgtrealmp_essentialsupport)(
229 const LTFAT_REAL g[], ltfat_int gl, LTFAT_REAL reltol,
230 ltfat_int* lefttail, ltfat_int* righttail);
231
232 int
233 LTFAT_NAME(dgtrealmp_execute_kpos)(
234 LTFAT_NAME(dgtrealmp_state)* p, kpoint pos1, kpoint pos2,
235 ltfat_int* m2, ltfat_int* n2, ltfat_int* Mstep, ltfat_int* astep,
236 ksize* kdim2, kanchor* kmid2, kpoint* kstart2);
237
238 int
239 LTFAT_NAME(dgtrealmp_execute_indices)(
240 LTFAT_NAME(dgtrealmp_state)* p, kpoint origpos, kpoint* pos,
241 ltfat_int* m2start, ltfat_int* n2start, ksize* kdim2, kanchor* kmid2,
242 kpoint* kstart2);
243
244 LTFAT_COMPLEX*
245 LTFAT_NAME(dgtrealmp_execute_pickkernel)(
246 LTFAT_NAME(kerns)* currkern, ltfat_int m, ltfat_int n,
247 ltfat_phaseconvention pconv);
248
249 LTFAT_COMPLEX*
250 LTFAT_NAME(dgtrealmp_execute_pickmod)(
251 LTFAT_NAME(kerns)* currkern, ltfat_int m, ltfat_int n,
252 ltfat_phaseconvention pconv);
253
254 int
255 LTFAT_NAME(dgtrealmp_execute_findmaxatom)(
256 LTFAT_NAME(dgtrealmp_state)* p, kpoint* pos);
257 // ltfat_int* m, ltfat_int* n, ltfat_int* w);
258
259 int
260 LTFAT_NAME(dgtrealmp_execute_findneighbors)(
261 LTFAT_NAME(dgtrealmp_state)* p, kpoint pos,
262 kpoint* nBuf, size_t* nCount);
263
264 int
265 LTFAT_NAME(dgtrealmp_execute_updateresiduum)(
266 LTFAT_NAME(dgtrealmp_state)* p, kpoint pos, LTFAT_COMPLEX cval,
267 int do_substract);
268
269 LTFAT_REAL
270 LTFAT_NAME(dgtrealmp_execute_atenergy)(
271 LTFAT_COMPLEX ainprod, LTFAT_COMPLEX cval);
272
273 LTFAT_REAL
274 LTFAT_NAME(dgtrealmp_execute_adjustedenergy)(
275 LTFAT_NAME(dgtrealmp_state)* p, kpoint pos, LTFAT_COMPLEX cval);
276
277 void
278 LTFAT_NAME(dgtrealmp_execute_conjatpairprod)(
279 LTFAT_NAME(dgtrealmp_state)* p, kpoint pos,
280 LTFAT_COMPLEX* atinprod, LTFAT_REAL* oneover1minatprodnorm);
281
282 void
283 LTFAT_NAME(dgtrealmp_execute_dualprodandprojenergy)(
284 LTFAT_NAME(dgtrealmp_state)* p, kpoint pos, LTFAT_COMPLEX cval,
285 LTFAT_COMPLEX* cvaldual, LTFAT_REAL* projenergy);
286
287 LTFAT_REAL
288 LTFAT_NAME(dgtrealmp_execute_mp)(
289 LTFAT_NAME(dgtrealmp_state)* p, LTFAT_COMPLEX cval,
290 kpoint pos, LTFAT_COMPLEX** cout);
291
292 int
293 LTFAT_NAME(dgtrealmp_execute_cyclicmp)(
294 LTFAT_NAME(dgtrealmp_state)* p,
295 kpoint origpos, LTFAT_COMPLEX** cout);
296
297 int
298 LTFAT_NAME(dgtrealmp_execute_selfprojmp)(
299 LTFAT_NAME(dgtrealmp_state)* p,
300 kpoint origpos, LTFAT_COMPLEX** cout);
301
302 int
303 LTFAT_NAME(dgtrealmp_execute_locomp)(
304 LTFAT_NAME(dgtrealmp_state)* p,
305 kpoint origpos, LTFAT_COMPLEX** cout);
306
307 LTFAT_REAL
308 LTFAT_NAME(dgtrealmp_execute_invmp)(
309 LTFAT_NAME(dgtrealmp_state)* p,
310 kpoint pos, LTFAT_COMPLEX** cout);
311
312 LTFAT_REAL
313 LTFAT_NAME(dgtrealmp_execute_realatenergy)(
314 LTFAT_NAME(dgtrealmp_state)* p,
315 kpoint pos, LTFAT_COMPLEX cval);
316
317 LTFAT_REAL
318 LTFAT_NAME(pedantic_callback)(void* userdata,
319 LTFAT_COMPLEX cval, ltfat_int pos);
320 #ifdef __cplusplus
321 } // extern "C"
322 #endif
323
324 #endif
325