1 /**************************************************************************************************
2 *                                                                                                 *
3 * This file is part of HPIPM.                                                                     *
4 *                                                                                                 *
5 * HPIPM -- High-Performance Interior Point Method.                                                *
6 * Copyright (C) 2017-2018 by Gianluca Frison.                                                     *
7 * Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl.              *
8 * All rights reserved.                                                                            *
9 *                                                                                                 *
10 * This program is free software: you can redistribute it and/or modify                            *
11 * it under the terms of the GNU General Public License as published by                            *
12 * the Free Software Foundation, either version 3 of the License, or                               *
13 * (at your option) any later version.                                                             *
14 *                                                                                                 *
15 * This program is distributed in the hope that it will be useful,                                 *
16 * but WITHOUT ANY WARRANTY; without even the implied warranty of                                  *
17 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                                   *
18 * GNU General Public License for more details.                                                    *
19 *                                                                                                 *
20 * You should have received a copy of the GNU General Public License                               *
21 * along with this program.  If not, see <https://www.gnu.org/licenses/>.                          *
22 *                                                                                                 *
23 * The authors designate this particular file as subject to the "Classpath" exception              *
24 * as provided by the authors in the LICENSE file that accompanied this code.                      *
25 *                                                                                                 *
26 * Author: Gianluca Frison, gianluca.frison (at) imtek.uni-freiburg.de                             *
27 *                                                                                                 *
28 **************************************************************************************************/
29 
30 
31 
32 #if defined(RUNTIME_CHECKS)
33 #include <stdlib.h>
34 #include <stdio.h>
35 #endif
36 
37 #include <blasfeo_target.h>
38 #include <blasfeo_common.h>
39 #include <blasfeo_d_aux.h>
40 #include <blasfeo_d_blas.h>
41 
42 #include "../include/hpipm_d_rk_int.h"
43 #include "../include/hpipm_d_erk_int.h"
44 
45 
46 
d_memsize_erk_int(struct d_erk_arg * erk_arg,int nx,int np,int nf_max,int na_max)47 int d_memsize_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int na_max)
48 	{
49 
50 	int ns = erk_arg->rk_data->ns;
51 
52 	int nX = nx*(1+nf_max);
53 
54 	int steps = erk_arg->steps;
55 
56 	int size = 0;
57 
58 	size += 1*np*sizeof(double); // p
59 	size += 1*nX*sizeof(double); // x_for
60 	size += 1*ns*nX*sizeof(double); // K
61 	size += 1*nX*sizeof(double); // x_tmp
62 	if(na_max>0)
63 		{
64 //		size += 1*nX*(steps+1)*sizeof(double); // x_traj XXX
65 		size += 1*nx*(ns*steps+1)*sizeof(double); // x_traj
66 //		size += 1*ns*nX*steps*sizeof(double); // K
67 //		size += 1*nX*ns*sizeof(double); // x_tmp
68 		size += 1*nf_max*(steps+1)*sizeof(double); // l // XXX *na_max ???
69 		size += 1*(nx+nf_max)*sizeof(double); // adj_in // XXX *na_max ???
70 		size += 1*nf_max*ns*sizeof(double); // adj_tmp // XXX *na_max ???
71 		}
72 
73 	return size;
74 
75 	}
76 
77 
78 
d_create_erk_int(struct d_erk_arg * erk_arg,int nx,int np,int nf_max,int na_max,struct d_erk_workspace * ws,void * mem)79 void d_create_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int na_max, struct d_erk_workspace *ws, void *mem)
80 	{
81 
82 	ws->erk_arg = erk_arg;
83 	ws->nx = nx;
84 	ws->np = np;
85 	ws->nf_max = nf_max;
86 	ws->na_max = na_max;
87 
88 	int ns = erk_arg->rk_data->ns;
89 
90 	int nX = nx*(1+nf_max);
91 
92 	int steps = erk_arg->steps;
93 
94 	double *d_ptr = mem;
95 
96 	//
97 	ws->p = d_ptr;
98 	d_ptr += np;
99 	//
100 	ws->x_for = d_ptr;
101 	d_ptr += nX;
102 	//
103 	ws->K = d_ptr;
104 	d_ptr += ns*nX;
105 	//
106 	ws->x_tmp = d_ptr;
107 	d_ptr += nX;
108 	//
109 	if(na_max>0)
110 		{
111 		//
112 //		ws->x_for = d_ptr;
113 //		d_ptr += nX*(steps+1);
114 		//
115 		ws->x_traj = d_ptr;
116 		d_ptr += nx*(ns*steps+1);
117 		//
118 //		ws->K = d_ptr;
119 //		d_ptr += ns*nX*steps;
120 		//
121 //		ws->x_tmp = d_ptr;
122 //		d_ptr += nX*ns;
123 		//
124 		ws->l = d_ptr;
125 		d_ptr += nf_max*(steps+1);
126 		//
127 		ws->adj_in = d_ptr;
128 		d_ptr += nx+nf_max;
129 		//
130 		ws->adj_tmp = d_ptr;
131 		d_ptr += nf_max*ns;
132 		}
133 
134 
135 	ws->memsize = d_memsize_erk_int(erk_arg, nx, np, nf_max, na_max);
136 
137 
138 	char *c_ptr = (char *) d_ptr;
139 
140 
141 #if defined(RUNTIME_CHECKS)
142 	if(c_ptr > ((char *) mem) + ws->memsize)
143 		{
144 		printf("\nCreate_erk_int: outsize memory bounds! %p %p\n\n", c_ptr, ((char *) mem) + ws->memsize);
145 		exit(1);
146 		}
147 #endif
148 
149 
150 	return;
151 
152 	}
153 
154 
155 
d_init_erk_int(int nf,int na,double * x0,double * p0,double * fs0,double * bs0,void (* vde_for)(int t,double * x,double * p,void * ode_args,double * xdot),void (* vde_adj)(int t,double * adj_in,void * ode_args,double * adj_out),void * ode_args,struct d_erk_workspace * ws)156 void d_init_erk_int(int nf, int na, double *x0, double *p0, double *fs0, double *bs0, void (*vde_for)(int t, double *x, double *p, void *ode_args, double *xdot), void (*vde_adj)(int t, double *adj_in, void *ode_args, double *adj_out), void *ode_args, struct d_erk_workspace *ws)
157 	{
158 
159 	int ii;
160 
161 	ws->nf = nf;
162 	ws->na = na;
163 
164 	int nx = ws->nx;
165 	int np = ws->np;
166 
167 	int nX = nx*(1+nf);
168 	int nA = np+nx; // XXX
169 
170 	int steps = ws->erk_arg->steps;
171 
172 	double *x_for = ws->x_for;
173 	double *p = ws->p;
174 	double *l = ws->l;
175 
176 	for(ii=0; ii<nx; ii++)
177 		x_for[ii] = x0[ii];
178 
179 	for(ii=0; ii<nx*nf; ii++)
180 		x_for[nx+ii] = fs0[ii];
181 
182 	for(ii=0; ii<np; ii++)
183 		p[ii] = p0[ii];
184 
185 	if(na>0) // TODO what if na>1 !!!
186 		{
187 		for(ii=0; ii<np; ii++)
188 			l[nA*steps+ii] = 0.0;
189 		for(ii=0; ii<nx; ii++)
190 			l[nA*steps+np+ii] = bs0[ii];
191 		}
192 
193 //	ws->ode = ode;
194 	ws->vde_for = vde_for;
195 	ws->vde_adj = vde_adj;
196 	ws->ode_args = ode_args;
197 
198 //	d_print_mat(1, nx*nf, x, 1);
199 //	d_print_mat(1, np, p, 1);
200 //	printf("\n%p %p\n", ode, ode_args);
201 
202 	return;
203 
204 	}
205 
206 
207 
#if 0
// Overwrite the np parameter values stored in ws->p with the values in p0.
// This function is currently compiled out.
void d_update_p_erk_int(double *p0, struct d_erk_workspace *ws)
	{

	double *p_dst = ws->p;
	int n_par = ws->np;
	int kk;

	for(kk=0; kk<n_par; kk++)
		p_dst[kk] = p0[kk];

	return;

	}
#endif
225 
226 
227 
d_erk_int(struct d_erk_workspace * ws)228 void d_erk_int(struct d_erk_workspace *ws)
229 	{
230 
231 	int steps = ws->erk_arg->steps;
232 	double h = ws->erk_arg->h;
233 
234 	struct d_rk_data *rk_data = ws->erk_arg->rk_data;
235 	int nx = ws->nx;
236 	int np = ws->np;
237 	int nf = ws->nf;
238 	int na = ws->na;
239 	double *K0 = ws->K;
240 	double *x0 = ws->x_for;
241 	double *x1 = ws->x_for;
242 	double *x_traj = ws->x_traj;
243 	double *p = ws->p;
244 	double *x_tmp = ws->x_tmp;
245 	double *adj_in = ws->adj_in;
246 	double *adj_tmp = ws->adj_tmp;
247 
248 	double *l0, *l1;
249 
250 	int ns = rk_data->ns;
251 	double *A_rk = rk_data->A_rk;
252 	double *B_rk = rk_data->B_rk;
253 	double *C_rk = rk_data->C_rk;
254 
255 	struct blasfeo_dvec sxt; // XXX
256 	struct blasfeo_dvec sK; // XXX
257 	sxt.pa = x_tmp; // XXX
258 
259 	int ii, jj, step, ss;
260 	double t, a, b;
261 
262 	int nX = nx*(1+nf);
263 	int nA = nx+np; // XXX
264 
265 //printf("\nnf %d na %d nX %d nA %d\n", nf, na, nX, nA);
266 	// forward sweep
267 
268 	// TODO no need to save the entire [x Su Sx] & sens, but only [x] & sens !!!
269 
270 	t = 0.0; // TODO plus time of multiple-shooting stage !!!
271 	if(na>0)
272 		{
273 		x_traj = ws->x_traj;
274 		for(ii=0; ii<nx; ii++)
275 			x_traj[ii] = x0[ii];
276 		x_traj += nx;
277 		}
278 	for(step=0; step<steps; step++)
279 		{
280 //		if(na>0)
281 //			{
282 //			x0 = ws->x_for + step*nX;
283 //			x1 = ws->x_for + (step+1)*nX;
284 //			for(ii=0; ii<nX; ii++)
285 //				x1[ii] = x0[ii];
286 //			K0 = ws->K + ns*step*nX;
287 //			}
288 		for(ss=0; ss<ns; ss++)
289 			{
290 			for(ii=0; ii<nX; ii++)
291 				x_tmp[ii] = x0[ii];
292 			for(ii=0; ii<ss; ii++)
293 				{
294 				a = A_rk[ss+ns*ii];
295 				if(a!=0)
296 					{
297 					a *= h;
298 #if 0
299 					sK.pa = K0+ii*nX; // XXX
300 					blasfeo_daxpy(nX, a, &sK, 0, &sxt, 0, &sxt, 0); // XXX
301 #else
302 					for(jj=0; jj<nX; jj++)
303 						x_tmp[jj] += a*K0[jj+ii*(nX)];
304 #endif
305 					}
306 				}
307 			if(na>0)
308 				{
309 				for(ii=0; ii<nx; ii++)
310 					x_traj[ii] = x_tmp[ii];
311 				x_traj += nx;
312 				}
313 			ws->vde_for(t+h*C_rk[ss], x_tmp, p, ws->ode_args, K0+ss*(nX));
314 			}
315 		for(ss=0; ss<ns; ss++)
316 			{
317 			b = h*B_rk[ss];
318 			for(ii=0; ii<nX; ii++)
319 				x1[ii] += b*K0[ii+ss*(nX)];
320 			}
321 		t += h;
322 		}
323 
324 	// adjoint sweep
325 
326 	if(na>0)
327 		{
328 		x_traj = ws->x_traj + nx*ns*steps;
329 		t = steps*h; // TODO plus time of multiple-shooting stage !!!
330 		for(step=steps-1; step>=0; step--)
331 			{
332 			l0 = ws->l + step*nA;
333 			l1 = ws->l + (step+1)*nA;
334 			x0 = ws->x_for + step*nX;
335 			K0 = ws->K + ns*step*nX; // XXX save all x insead !!!
336 			// TODO save all x instead of K !!!
337 			for(ss=ns-1; ss>=0; ss--)
338 				{
339 				// x
340 				for(ii=0; ii<nx; ii++)
341 					adj_in[0+ii] = x_traj[ii];
342 				x_traj -= nx;
343 //				for(ii=0; ii<nx; ii++)
344 //					adj_in[0+ii] = x0[ii];
345 //				for(ii=0; ii<ss; ii++)
346 //					{
347 //					a = A_rk[ss+ns*ii];
348 //					if(a!=0)
349 //						{
350 //						a *= h;
351 //						for(jj=0; jj<nx; jj++)
352 //							adj_in[0+jj] += a*K0[jj+ii*(nX)];
353 //						}
354 //					}
355 				// l
356 				b = h*B_rk[ss];
357 				for(ii=0; ii<nx; ii++)
358 					adj_in[nx+ii] = b*l1[np+ii];
359 				for(ii=ss+1; ii<ns; ii++)
360 					{
361 					a = A_rk[ii+ns*ss];
362 					if(a!=0)
363 						{
364 						a *= h;
365 						for(jj=0; jj<nx; jj++)
366 							adj_in[nx+jj] += a*adj_tmp[np+jj+ii*nA];
367 						}
368 					}
369 				// p
370 				for(ii=0; ii<np; ii++)
371 					adj_in[nx+nx+ii] = p[ii];
372 				// adj_vde
373 				ws->vde_adj(t+h*C_rk[ss], adj_in, ws->ode_args, adj_tmp+ss*nA);
374 				}
375 			// erk step
376 			for(ii=0; ii<nA; ii++) // TODO move in the erk step !!!
377 				l0[ii] = l1[ii];
378 			for(ss=0; ss<ns; ss++)
379 				for(ii=0; ii<nA; ii++)
380 					l0[ii] += adj_tmp[ii+ss*nA];
381 			t -= h;
382 			}
383 		}
384 
385 	return;
386 
387 	}
388 
389 
390 
391 
392