1 #ifndef _ltfat_dgtrealmp_private_h
2 #define _ltfat_dgtrealmp_private_h
3 #include "ltfat.h"
4 #include "ltfat/types.h"
5 #include "ltfat/macros.h"
6 
LTFAT_NAME(dgtrealmp_parbuf)7 struct LTFAT_NAME(dgtrealmp_parbuf)
8 {
9     LTFAT_REAL**        g;
10     ltfat_int*          gl;
11     ltfat_int*          a;
12     ltfat_int*          M;
13     int*         chanmask;
14     ltfat_int           P;
15     ltfat_dgtmp_params* params;
16     LTFAT_NAME(dgtrealmp_iterstep_callback)* iterstepcallback;
17     void*                        iterstepcallbackdata;
18 //    LTFAT_REAL          chirprate;
19 //    LTFAT_REAL          shiftby;
20 };
21 
22 struct ltfat_dgtmp_params
23 {
24     // ltfat_dgtrealmp_hint  hint;
25     ltfat_dgtmp_alg       alg;
26     double                atprodreltoldb;
27     double                atprodreltoladj;
28     long double           errtoldb;
29     long double           errtoladj;
30     double                kernrelthr;
31     size_t                maxit;
32     size_t                maxatoms;
33     size_t                iterstep;
34     int                   verbose;
35     int                   initwasrun;
36     int                   treelevels;
37     size_t                cycles;
38     ltfat_phaseconvention ptype;
39     int                   do_pedantic;
40 };
41 
42 typedef struct
43 {
44     ltfat_int height;
45     ltfat_int width;
46 } ksize;
47 
48 typedef struct
49 {
50     ltfat_int hmid;
51     ltfat_int wmid;
52 } kanchor;
53 
54 typedef struct
55 {
56     ltfat_int start;
57     ltfat_int end;
58 } krange;
59 
60 typedef struct
61 {
62     ltfat_int m;
63     ltfat_int n;
64     ltfat_int w;
65     ltfat_int n2;
66 } kpoint;
/*
 * PTOI(k): two-level index for kpoint `k`, meant to be written between the
 * brackets of a pointer-to-pointer access: arr[PTOI(k)] expands to
 * arr[(k).w][(k).m + p->M2[(k).w] * (k).n].
 * The expansion is deliberately unbalanced and relies on a variable named
 * `p` (with an `M2` array member) being in scope at the point of use.
 */
#define PTOI(k) (k).w][(k).m + p->M2[(k).w] * (k).n
/* Build a kpoint from (m, n, w); the fourth value handed to
 * LTFAT_STRUCTINIT (n2) is a copy of n. */
#define kpoint_init(m,n,w) LTFAT_STRUCTINIT(kpoint,m,n,w,n)
/* Build a kpoint from (m, n, n2, w) with an explicit n2. */
#define kpoint_init2(m,n,n2,w) LTFAT_STRUCTINIT(kpoint,m,n,w,n2)
/* Non-zero when k1 and k2 agree in m, n and w; n2 is not compared. */
#define kpoint_isequal(k1,k2) ((k1).m == (k2).m && (k1).n == (k2).n && (k1).w == (k2).w)
71 
72 
73 typedef struct
74 {
75     ksize              size;
76     kanchor             mid;
77     ltfat_int           kNo;
78     ltfat_int         kSkip;
79     LTFAT_COMPLEX**    mods;
80     LTFAT_COMPLEX*     kval;
81     krange*           range;
82     krange*          srange;
83     LTFAT_REAL       absthr;
84     double             Mrat;
85     double             arat;
86     ltfat_int         Mstep;
87     ltfat_int         astep;
88     LTFAT_COMPLEX*     atprods;
89     LTFAT_REAL* oneover1minatprodnorms;
90     ltfat_int   atprodsNo;
91     int cloned;
92      ltfat_phaseconvention ptype;
93 } LTFAT_NAME(kerns);
94 
95 
96 typedef struct
97 {
98     LTFAT_COMPLEX**        c;
99     LTFAT_REAL**           maxcols;
100     ltfat_int**            maxcolspos;
101     LTFAT_NAME(maxtree)**  tmaxtree;
102     LTFAT_NAME(maxtree)*** fmaxtree;
103     unsigned int**         suppind;
104     long double            err;
105     long double            fnorm2;
106     size_t                 currit;
107     size_t                 curratoms;
108     ltfat_int              P;
109     ltfat_int*             N;
110     LTFAT_COMPLEX**        cvalModBuf;
111     // LocOMP related
112     LTFAT_COMPLEX*         gramBuf;
113     LTFAT_COMPLEX*         cvalBuf;
114     LTFAT_COMPLEX*         cvalinvBuf;
115     kpoint*                cvalBufPos;
116     LTFAT_NAME_COMPLEX(hermsystemsolver_plan)* hplan;
117     // CyclicMP related
118     kpoint*                pBuf;
119     size_t                 pBufSize;
120     size_t                 pBufNo;
121 } LTFAT_NAME(dgtrealmpiter_state);
122 
123 
124 typedef struct
125 {
126     ltfat_int n;
127     ltfat_int w;
128     LTFAT_NAME(dgtrealmp_state)* state;
129 } LTFAT_NAME(dgtrealmp_state_closure);
130 
LTFAT_NAME(dgtrealmp_state)131 struct LTFAT_NAME(dgtrealmp_state)
132 {
133     LTFAT_NAME(dgtrealmpiter_state)* iterstate;
134     LTFAT_NAME(kerns)**             gramkerns; // PxP plans
135     LTFAT_NAME(dgtreal_plan)**       dgtplans;  // P plans
136     ltfat_int*        a;
137     ltfat_int*        M;
138     ltfat_int*       M2;
139     ltfat_int*        N;
140     int*       chanmask;
141     ltfat_int         P;
142     ltfat_int         L;
143     ltfat_dgtmp_params* params;
144     LTFAT_COMPLEX**     couttmp;
145     LTFAT_NAME(dgtrealmp_state_closure)** closures;
146     LTFAT_NAME(dgtrealmp_iterstep_callback)* callback;
147     void* userdata;
148 };
149 
150 static inline LTFAT_REAL
ltfat_norm(LTFAT_COMPLEX c)151 ltfat_norm(LTFAT_COMPLEX c)
152 {
153     return ltfat_real(c) * ltfat_real(c) + ltfat_imag(c) * ltfat_imag(c);
154 }
155 
156 static inline LTFAT_REAL
LTFAT_NAME(dgtrealmp_execute_projenergy)157 LTFAT_NAME(dgtrealmp_execute_projenergy)(
158     LTFAT_COMPLEX atinprod, LTFAT_COMPLEX cval)
159 {
160     LTFAT_REAL cr = ltfat_real(cval);
161     LTFAT_REAL ci = ltfat_imag(cval);
162     LTFAT_REAL cr2 = cr*cr;
163     LTFAT_REAL ci2 = ci*ci;
164 	LTFAT_REAL two = (LTFAT_REAL) 2.0;
165     return two*(cr2 + ci2 + ltfat_real(atinprod)*(cr2 - ci2) - two*ltfat_imag(atinprod)*cr*ci);
166 }
167 
168 static inline LTFAT_REAL
LTFAT_NAME(dgtrealmp_execute_dualprojenergy)169 LTFAT_NAME(dgtrealmp_execute_dualprojenergy)(
170     LTFAT_COMPLEX atinprod, LTFAT_REAL oneoveroneminatprodnorm, LTFAT_COMPLEX cval)
171 {
172     LTFAT_COMPLEX cvaldual = (cval - (atinprod) * conj(cval)) * oneoveroneminatprodnorm;
173     return LTFAT_NAME(dgtrealmp_execute_projenergy)( atinprod, cvaldual);
174 }
175 
176 /* BEGIN_C_DECLS */
177 #ifdef __cplusplus
178 extern "C" {
179 #endif
180 
181 int
182 LTFAT_NAME(dgtrealmpiter_init)(
183     ltfat_int a[], ltfat_int M[], ltfat_int P, ltfat_int L,
184     LTFAT_NAME(dgtrealmpiter_state)** state);
185 
186 int
187 LTFAT_NAME(dgtrealmpiter_done)(LTFAT_NAME(dgtrealmpiter_state)** state);
188 
189 int
190 LTFAT_NAME(dgtrealmp_kernel_cloneconj)(
191     LTFAT_NAME(kerns)* kin, LTFAT_NAME(kerns)** kout);
192 
193 int
194 LTFAT_NAME(dgtrealmp_kernel_init)(
195     const LTFAT_REAL* g[], ltfat_int gl[], ltfat_int a[], ltfat_int M[],
196     ltfat_int L, LTFAT_REAL reltol, ltfat_phaseconvention ptype,
197     LTFAT_NAME(kerns)** pout);
198 
199 int
200 LTFAT_NAME(dgtrealmp_kernel_done)(LTFAT_NAME(kerns)** k);
201 
202 int
203 LTFAT_NAME(dgtrealmp_kernel_modfi)(
204     const LTFAT_COMPLEX* kfirst, ksize size, kanchor mid, ltfat_int n, ltfat_int a, ltfat_int M,
205     LTFAT_COMPLEX* kmod);
206 
207 int
208 LTFAT_NAME(dgtrealmp_kernel_modti)(
209     const LTFAT_COMPLEX* kfirst, ksize size, kanchor mid, ltfat_int m, ltfat_int a, ltfat_int M,
210     LTFAT_COMPLEX* kmod);
211 
212 int
213 LTFAT_NAME(dgtrealmp_kernel_modfiexp)(
214     ksize size, kanchor mid, ltfat_int n, ltfat_int a, ltfat_int M,
215     LTFAT_COMPLEX* kmod);
216 
217 int
218 LTFAT_NAME(dgtrealmp_kernel_modtiexp)(
219     ksize size, kanchor mid, ltfat_int m, ltfat_int a, ltfat_int M,
220     LTFAT_COMPLEX* kmod);
221 
222 int
223 LTFAT_NAME(dgtrealmp_kernel_findsmallsize)(
224     const LTFAT_COMPLEX kernlarge[], ltfat_int M, ltfat_int N,
225     LTFAT_REAL reltol, LTFAT_REAL* absthr, ksize* size, kanchor* anchor);
226 
227 int
228 LTFAT_NAME(dgtrealmp_essentialsupport)(
229     const LTFAT_REAL g[], ltfat_int gl, LTFAT_REAL reltol,
230     ltfat_int* lefttail, ltfat_int* righttail);
231 
232 int
233 LTFAT_NAME(dgtrealmp_execute_kpos)(
234     LTFAT_NAME(dgtrealmp_state)* p, kpoint pos1, kpoint pos2,
235     ltfat_int* m2, ltfat_int* n2, ltfat_int* Mstep, ltfat_int* astep,
236     ksize* kdim2, kanchor* kmid2, kpoint* kstart2);
237 
238 int
239 LTFAT_NAME(dgtrealmp_execute_indices)(
240     LTFAT_NAME(dgtrealmp_state)* p, kpoint origpos, kpoint* pos,
241     ltfat_int* m2start, ltfat_int* n2start, ksize* kdim2, kanchor* kmid2,
242     kpoint* kstart2);
243 
244 LTFAT_COMPLEX*
245 LTFAT_NAME(dgtrealmp_execute_pickkernel)(
246     LTFAT_NAME(kerns)* currkern, ltfat_int m, ltfat_int n,
247     ltfat_phaseconvention pconv);
248 
249 LTFAT_COMPLEX*
250 LTFAT_NAME(dgtrealmp_execute_pickmod)(
251     LTFAT_NAME(kerns)* currkern, ltfat_int m, ltfat_int n,
252     ltfat_phaseconvention pconv);
253 
254 int
255 LTFAT_NAME(dgtrealmp_execute_findmaxatom)(
256     LTFAT_NAME(dgtrealmp_state)* p, kpoint* pos);
257 // ltfat_int* m, ltfat_int* n, ltfat_int* w);
258 
259 int
260 LTFAT_NAME(dgtrealmp_execute_findneighbors)(
261         LTFAT_NAME(dgtrealmp_state)* p, kpoint pos,
262         kpoint* nBuf, size_t* nCount);
263 
264 int
265 LTFAT_NAME(dgtrealmp_execute_updateresiduum)(
266     LTFAT_NAME(dgtrealmp_state)* p, kpoint pos, LTFAT_COMPLEX cval,
267     int do_substract);
268 
269 LTFAT_REAL
270 LTFAT_NAME(dgtrealmp_execute_atenergy)(
271     LTFAT_COMPLEX ainprod, LTFAT_COMPLEX cval);
272 
273 LTFAT_REAL
274 LTFAT_NAME(dgtrealmp_execute_adjustedenergy)(
275     LTFAT_NAME(dgtrealmp_state)* p, kpoint pos, LTFAT_COMPLEX cval);
276 
277 void
278 LTFAT_NAME(dgtrealmp_execute_conjatpairprod)(
279     LTFAT_NAME(dgtrealmp_state)* p, kpoint pos,
280     LTFAT_COMPLEX* atinprod, LTFAT_REAL* oneover1minatprodnorm);
281 
282 void
283 LTFAT_NAME(dgtrealmp_execute_dualprodandprojenergy)(
284     LTFAT_NAME(dgtrealmp_state)* p, kpoint pos, LTFAT_COMPLEX cval,
285     LTFAT_COMPLEX* cvaldual, LTFAT_REAL* projenergy);
286 
287 LTFAT_REAL
288 LTFAT_NAME(dgtrealmp_execute_mp)(
289     LTFAT_NAME(dgtrealmp_state)* p, LTFAT_COMPLEX cval,
290     kpoint pos, LTFAT_COMPLEX** cout);
291 
292 int
293 LTFAT_NAME(dgtrealmp_execute_cyclicmp)(
294     LTFAT_NAME(dgtrealmp_state)* p,
295     kpoint origpos, LTFAT_COMPLEX** cout);
296 
297 int
298 LTFAT_NAME(dgtrealmp_execute_selfprojmp)(
299     LTFAT_NAME(dgtrealmp_state)* p,
300     kpoint origpos, LTFAT_COMPLEX** cout);
301 
302 int
303 LTFAT_NAME(dgtrealmp_execute_locomp)(
304     LTFAT_NAME(dgtrealmp_state)* p,
305     kpoint origpos, LTFAT_COMPLEX** cout);
306 
307 LTFAT_REAL
308 LTFAT_NAME(dgtrealmp_execute_invmp)(
309     LTFAT_NAME(dgtrealmp_state)* p,
310     kpoint pos, LTFAT_COMPLEX** cout);
311 
312 LTFAT_REAL
313 LTFAT_NAME(dgtrealmp_execute_realatenergy)(
314     LTFAT_NAME(dgtrealmp_state)* p,
315     kpoint pos, LTFAT_COMPLEX cval);
316 
317 LTFAT_REAL
318 LTFAT_NAME(pedantic_callback)(void* userdata,
319                               LTFAT_COMPLEX cval, ltfat_int pos);
320 #ifdef __cplusplus
321 }  // extern "C"
322 #endif
323 
324 #endif
325