1 /**************************************************************************************************
2 * *
3 * This file is part of HPIPM. *
4 * *
5 * HPIPM -- High-Performance Interior Point Method. *
6 * Copyright (C) 2017-2018 by Gianluca Frison. *
7 * Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8 * All rights reserved. *
9 * *
10 * This program is free software: you can redistribute it and/or modify *
11 * it under the terms of the GNU General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, or *
13 * (at your option) any later version *.
14 * *
15 * This program is distributed in the hope that it will be useful, *
16 * but WITHOUT ANY WARRANTY; without even the implied warranty of *
17 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
18 * GNU General Public License for more details. *
19 * *
20 * You should have received a copy of the GNU General Public License *
21 * along with this program. If not, see <https://www.gnu.org/licenses/>. *
22 * *
23 * The authors designate this particular file as subject to the "Classpath" exception *
24 * as provided by the authors in the LICENSE file that accompained this code. *
25 * *
26 * Author: Gianluca Frison, gianluca.frison (at) imtek.uni-freiburg.de *
27 * *
28 **************************************************************************************************/
29
30
31
32 #if defined(RUNTIME_CHECKS)
33 #include <stdlib.h>
34 #include <stdio.h>
35 #endif
36
37 #include <blasfeo_target.h>
38 #include <blasfeo_common.h>
39 #include <blasfeo_d_aux.h>
40 #include <blasfeo_d_blas.h>
41
42 #include "../include/hpipm_d_rk_int.h"
43 #include "../include/hpipm_d_erk_int.h"
44
45
46
d_memsize_erk_int(struct d_erk_arg * erk_arg,int nx,int np,int nf_max,int na_max)47 int d_memsize_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int na_max)
48 {
49
50 int ns = erk_arg->rk_data->ns;
51
52 int nX = nx*(1+nf_max);
53
54 int steps = erk_arg->steps;
55
56 int size = 0;
57
58 size += 1*np*sizeof(double); // p
59 size += 1*nX*sizeof(double); // x_for
60 size += 1*ns*nX*sizeof(double); // K
61 size += 1*nX*sizeof(double); // x_tmp
62 if(na_max>0)
63 {
64 // size += 1*nX*(steps+1)*sizeof(double); // x_traj XXX
65 size += 1*nx*(ns*steps+1)*sizeof(double); // x_traj
66 // size += 1*ns*nX*steps*sizeof(double); // K
67 // size += 1*nX*ns*sizeof(double); // x_tmp
68 size += 1*nf_max*(steps+1)*sizeof(double); // l // XXX *na_max ???
69 size += 1*(nx+nf_max)*sizeof(double); // adj_in // XXX *na_max ???
70 size += 1*nf_max*ns*sizeof(double); // adj_tmp // XXX *na_max ???
71 }
72
73 return size;
74
75 }
76
77
78
d_create_erk_int(struct d_erk_arg * erk_arg,int nx,int np,int nf_max,int na_max,struct d_erk_workspace * ws,void * mem)79 void d_create_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int na_max, struct d_erk_workspace *ws, void *mem)
80 {
81
82 ws->erk_arg = erk_arg;
83 ws->nx = nx;
84 ws->np = np;
85 ws->nf_max = nf_max;
86 ws->na_max = na_max;
87
88 int ns = erk_arg->rk_data->ns;
89
90 int nX = nx*(1+nf_max);
91
92 int steps = erk_arg->steps;
93
94 double *d_ptr = mem;
95
96 //
97 ws->p = d_ptr;
98 d_ptr += np;
99 //
100 ws->x_for = d_ptr;
101 d_ptr += nX;
102 //
103 ws->K = d_ptr;
104 d_ptr += ns*nX;
105 //
106 ws->x_tmp = d_ptr;
107 d_ptr += nX;
108 //
109 if(na_max>0)
110 {
111 //
112 // ws->x_for = d_ptr;
113 // d_ptr += nX*(steps+1);
114 //
115 ws->x_traj = d_ptr;
116 d_ptr += nx*(ns*steps+1);
117 //
118 // ws->K = d_ptr;
119 // d_ptr += ns*nX*steps;
120 //
121 // ws->x_tmp = d_ptr;
122 // d_ptr += nX*ns;
123 //
124 ws->l = d_ptr;
125 d_ptr += nf_max*(steps+1);
126 //
127 ws->adj_in = d_ptr;
128 d_ptr += nx+nf_max;
129 //
130 ws->adj_tmp = d_ptr;
131 d_ptr += nf_max*ns;
132 }
133
134
135 ws->memsize = d_memsize_erk_int(erk_arg, nx, np, nf_max, na_max);
136
137
138 char *c_ptr = (char *) d_ptr;
139
140
141 #if defined(RUNTIME_CHECKS)
142 if(c_ptr > ((char *) mem) + ws->memsize)
143 {
144 printf("\nCreate_erk_int: outsize memory bounds! %p %p\n\n", c_ptr, ((char *) mem) + ws->memsize);
145 exit(1);
146 }
147 #endif
148
149
150 return;
151
152 }
153
154
155
d_init_erk_int(int nf,int na,double * x0,double * p0,double * fs0,double * bs0,void (* vde_for)(int t,double * x,double * p,void * ode_args,double * xdot),void (* vde_adj)(int t,double * adj_in,void * ode_args,double * adj_out),void * ode_args,struct d_erk_workspace * ws)156 void d_init_erk_int(int nf, int na, double *x0, double *p0, double *fs0, double *bs0, void (*vde_for)(int t, double *x, double *p, void *ode_args, double *xdot), void (*vde_adj)(int t, double *adj_in, void *ode_args, double *adj_out), void *ode_args, struct d_erk_workspace *ws)
157 {
158
159 int ii;
160
161 ws->nf = nf;
162 ws->na = na;
163
164 int nx = ws->nx;
165 int np = ws->np;
166
167 int nX = nx*(1+nf);
168 int nA = np+nx; // XXX
169
170 int steps = ws->erk_arg->steps;
171
172 double *x_for = ws->x_for;
173 double *p = ws->p;
174 double *l = ws->l;
175
176 for(ii=0; ii<nx; ii++)
177 x_for[ii] = x0[ii];
178
179 for(ii=0; ii<nx*nf; ii++)
180 x_for[nx+ii] = fs0[ii];
181
182 for(ii=0; ii<np; ii++)
183 p[ii] = p0[ii];
184
185 if(na>0) // TODO what if na>1 !!!
186 {
187 for(ii=0; ii<np; ii++)
188 l[nA*steps+ii] = 0.0;
189 for(ii=0; ii<nx; ii++)
190 l[nA*steps+np+ii] = bs0[ii];
191 }
192
193 // ws->ode = ode;
194 ws->vde_for = vde_for;
195 ws->vde_adj = vde_adj;
196 ws->ode_args = ode_args;
197
198 // d_print_mat(1, nx*nf, x, 1);
199 // d_print_mat(1, np, p, 1);
200 // printf("\n%p %p\n", ode, ode_args);
201
202 return;
203
204 }
205
206
207
#if 0
/*
 * Excluded from compilation by the surrounding #if 0.
 * Overwrites the np parameters stored in the workspace with the values in p0.
 */
void d_update_p_erk_int(double *p0, struct d_erk_workspace *ws)
{

	const int n_par = ws->np;
	double *dst = ws->p;
	int k;

	for(k=0; k<n_par; k++)
		dst[k] = p0[k];

	return;

}
#endif
225
226
227
d_erk_int(struct d_erk_workspace * ws)228 void d_erk_int(struct d_erk_workspace *ws)
229 {
230
231 int steps = ws->erk_arg->steps;
232 double h = ws->erk_arg->h;
233
234 struct d_rk_data *rk_data = ws->erk_arg->rk_data;
235 int nx = ws->nx;
236 int np = ws->np;
237 int nf = ws->nf;
238 int na = ws->na;
239 double *K0 = ws->K;
240 double *x0 = ws->x_for;
241 double *x1 = ws->x_for;
242 double *x_traj = ws->x_traj;
243 double *p = ws->p;
244 double *x_tmp = ws->x_tmp;
245 double *adj_in = ws->adj_in;
246 double *adj_tmp = ws->adj_tmp;
247
248 double *l0, *l1;
249
250 int ns = rk_data->ns;
251 double *A_rk = rk_data->A_rk;
252 double *B_rk = rk_data->B_rk;
253 double *C_rk = rk_data->C_rk;
254
255 struct blasfeo_dvec sxt; // XXX
256 struct blasfeo_dvec sK; // XXX
257 sxt.pa = x_tmp; // XXX
258
259 int ii, jj, step, ss;
260 double t, a, b;
261
262 int nX = nx*(1+nf);
263 int nA = nx+np; // XXX
264
265 //printf("\nnf %d na %d nX %d nA %d\n", nf, na, nX, nA);
266 // forward sweep
267
268 // TODO no need to save the entire [x Su Sx] & sens, but only [x] & sens !!!
269
270 t = 0.0; // TODO plus time of multiple-shooting stage !!!
271 if(na>0)
272 {
273 x_traj = ws->x_traj;
274 for(ii=0; ii<nx; ii++)
275 x_traj[ii] = x0[ii];
276 x_traj += nx;
277 }
278 for(step=0; step<steps; step++)
279 {
280 // if(na>0)
281 // {
282 // x0 = ws->x_for + step*nX;
283 // x1 = ws->x_for + (step+1)*nX;
284 // for(ii=0; ii<nX; ii++)
285 // x1[ii] = x0[ii];
286 // K0 = ws->K + ns*step*nX;
287 // }
288 for(ss=0; ss<ns; ss++)
289 {
290 for(ii=0; ii<nX; ii++)
291 x_tmp[ii] = x0[ii];
292 for(ii=0; ii<ss; ii++)
293 {
294 a = A_rk[ss+ns*ii];
295 if(a!=0)
296 {
297 a *= h;
298 #if 0
299 sK.pa = K0+ii*nX; // XXX
300 blasfeo_daxpy(nX, a, &sK, 0, &sxt, 0, &sxt, 0); // XXX
301 #else
302 for(jj=0; jj<nX; jj++)
303 x_tmp[jj] += a*K0[jj+ii*(nX)];
304 #endif
305 }
306 }
307 if(na>0)
308 {
309 for(ii=0; ii<nx; ii++)
310 x_traj[ii] = x_tmp[ii];
311 x_traj += nx;
312 }
313 ws->vde_for(t+h*C_rk[ss], x_tmp, p, ws->ode_args, K0+ss*(nX));
314 }
315 for(ss=0; ss<ns; ss++)
316 {
317 b = h*B_rk[ss];
318 for(ii=0; ii<nX; ii++)
319 x1[ii] += b*K0[ii+ss*(nX)];
320 }
321 t += h;
322 }
323
324 // adjoint sweep
325
326 if(na>0)
327 {
328 x_traj = ws->x_traj + nx*ns*steps;
329 t = steps*h; // TODO plus time of multiple-shooting stage !!!
330 for(step=steps-1; step>=0; step--)
331 {
332 l0 = ws->l + step*nA;
333 l1 = ws->l + (step+1)*nA;
334 x0 = ws->x_for + step*nX;
335 K0 = ws->K + ns*step*nX; // XXX save all x insead !!!
336 // TODO save all x instead of K !!!
337 for(ss=ns-1; ss>=0; ss--)
338 {
339 // x
340 for(ii=0; ii<nx; ii++)
341 adj_in[0+ii] = x_traj[ii];
342 x_traj -= nx;
343 // for(ii=0; ii<nx; ii++)
344 // adj_in[0+ii] = x0[ii];
345 // for(ii=0; ii<ss; ii++)
346 // {
347 // a = A_rk[ss+ns*ii];
348 // if(a!=0)
349 // {
350 // a *= h;
351 // for(jj=0; jj<nx; jj++)
352 // adj_in[0+jj] += a*K0[jj+ii*(nX)];
353 // }
354 // }
355 // l
356 b = h*B_rk[ss];
357 for(ii=0; ii<nx; ii++)
358 adj_in[nx+ii] = b*l1[np+ii];
359 for(ii=ss+1; ii<ns; ii++)
360 {
361 a = A_rk[ii+ns*ss];
362 if(a!=0)
363 {
364 a *= h;
365 for(jj=0; jj<nx; jj++)
366 adj_in[nx+jj] += a*adj_tmp[np+jj+ii*nA];
367 }
368 }
369 // p
370 for(ii=0; ii<np; ii++)
371 adj_in[nx+nx+ii] = p[ii];
372 // adj_vde
373 ws->vde_adj(t+h*C_rk[ss], adj_in, ws->ode_args, adj_tmp+ss*nA);
374 }
375 // erk step
376 for(ii=0; ii<nA; ii++) // TODO move in the erk step !!!
377 l0[ii] = l1[ii];
378 for(ss=0; ss<ns; ss++)
379 for(ii=0; ii<nA; ii++)
380 l0[ii] += adj_tmp[ii+ss*nA];
381 t -= h;
382 }
383 }
384
385 return;
386
387 }
388
389
390
391
392